Harshasnade committed on
Commit
0966609
·
1 Parent(s): 3dccfa9

Deploy Backend (No Frontend)

Browse files
.dockerignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .gitignore
3
+ __pycache__
4
+ *.pyc
5
+ .DS_Store
6
+ model/test_images
7
+ venv
8
+ env
9
+ node_modules
10
+ frontend/tests
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9

WORKDIR /app

# System libraries needed at runtime by OpenCV (libGL) and GLib.
# libgl1 is used instead of libgl1-mesa-glx: the latter is only a transitional
# package and is not available on newer Debian base images.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage the Docker layer cache
COPY backend/requirements_web.txt .
RUN pip install --no-cache-dir -r requirements_web.txt

# Copy the rest of the application
COPY backend/ backend/
COPY model/ model/

# app.py resolves its upload folders relative to its own directory
WORKDIR /app/backend

# Create the directories the backend writes to
RUN mkdir -p uploads history_uploads

# Expose Hugging Face default port
EXPOSE 7860

# Run the application
CMD ["python", "app.py"]
backend/.DS_Store ADDED
Binary file (6.15 kB). View file
 
backend/app.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_from_directory
2
+ from flask_cors import CORS
3
+ import sys
4
+ import os
5
+
6
+ # Add model directory to path
7
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'model')))
8
+ import datetime
9
+ import torch
10
+ import cv2
11
+ import os
12
+ import numpy as np
13
+ import ssl
14
+ import base64
15
+ from werkzeug.utils import secure_filename
16
+ import io
17
+ from PIL import Image
18
+ from src import video_inference
19
+
20
+ # Disable SSL verification
21
+ ssl._create_default_https_context = ssl._create_unverified_context
22
+ import albumentations as A
23
+ from albumentations.pytorch import ToTensorV2
24
+ from albumentations.pytorch import ToTensorV2
25
+ from src.models import DeepfakeDetector
26
+ from src.config import Config
27
+ import database
28
+
29
+ try:
30
+ from safetensors.torch import load_file
31
+ SAFETENSORS_AVAILABLE = True
32
+ except ImportError:
33
+ SAFETENSORS_AVAILABLE = False
34
+
35
+ app = Flask(__name__, static_folder='../frontend', static_url_path='')
36
+ CORS(app)
37
+
38
+ # Configuration
39
+ UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
40
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp', 'mp4', 'avi', 'mov', 'webm'}
41
+ HISTORY_FOLDER = os.path.join(os.path.dirname(__file__), 'history_uploads')
42
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
43
+ os.makedirs(HISTORY_FOLDER, exist_ok=True)
44
+
45
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
46
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
47
+
48
+ # Global model and transform
49
+ device = torch.device(Config.DEVICE)
50
+ model = None
51
+ transform = None
52
+
53
def get_transform():
    """Build the preprocessing pipeline applied to images before inference.

    Returns:
        An albumentations ``Compose`` that resizes to a
        ``Config.IMAGE_SIZE`` square, normalizes with the fixed mean/std
        below, and converts the result to a tensor with ``ToTensorV2``.
    """
    side = Config.IMAGE_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
59
+
60
def load_model():
    """Load the trained deepfake detection model and build the transform.

    Populates the module-level ``model`` and ``transform`` globals.

    Returns:
        ``(model, transform)``. ``model`` is ``None`` when the checkpoint file
        is missing or cannot be loaded; ``transform`` is always set.
    """
    global model, transform

    checkpoint_dir = Config.CHECKPOINT_DIR
    # Explicitly target the model requested by the user
    target_model_name = "best_model.safetensors"
    checkpoint_path = os.path.join(checkpoint_dir, target_model_name)

    print(f"Using device: {device}")

    # The transform does not depend on the checkpoint, so it is always built.
    transform = get_transform()

    # Check for the file before constructing the network: there is no point
    # building a model that would be discarded immediately.
    if not os.path.exists(checkpoint_path):
        print(f"❌ CRITICAL ERROR: Model file not found at: {checkpoint_path}")
        print(f"Please ensure '{target_model_name}' exists in '{checkpoint_dir}'")
        model = None
        return model, transform

    # Initialize with pretrained=True so that keys missing from the checkpoint
    # keep valid pretrained weights instead of random noise. This matters when
    # the checkpoint only contains the finetuned layers.
    detector = DeepfakeDetector(pretrained=True)
    detector.to(device)
    detector.eval()

    try:
        print(f"Loading checkpoint: {checkpoint_path}")
        if checkpoint_path.endswith(".safetensors"):
            if not SAFETENSORS_AVAILABLE:
                # torch.load cannot parse the safetensors format; fail with a
                # message that names the real problem.
                raise RuntimeError(
                    "the 'safetensors' package is required to read "
                    f"'{target_model_name}' but is not installed"
                )
            state_dict = load_file(checkpoint_path)
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location=device)

        # strict=False because the checkpoint might be a partial save (e.g.
        # only finetuned layers) or have minor architecture mismatches.
        missing_keys, unexpected_keys = detector.load_state_dict(state_dict, strict=False)
        # Re-assert inference mode once the weights are in place.
        detector.eval()
        model = detector

        print("✅ Model loaded successfully!")
        if missing_keys:
            print(f"ℹ️ {len(missing_keys)} keys missing from checkpoint (using pretrained defaults).")
        if unexpected_keys:
            print(f"ℹ️ {len(unexpected_keys)} unexpected keys in checkpoint.")

    except Exception as e:
        print(f"❌ Error loading checkpoint: {e}")
        print("Predictions will fail until this is resolved.")
        model = None

    return model, transform
111
+
112
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively.
    Names without a dot are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
114
+
115
def predict_image(image_path):
    """Classify a single image file and render a heatmap overlay.

    Args:
        image_path: Path of the image on disk; it is read with ``cv2.imread``.

    Returns:
        ``(result, error)``. On success ``error`` is ``None`` and ``result``
        is a dict with the keys ``prediction`` ("FAKE" or "REAL"),
        ``confidence``, ``fake_probability``, ``real_probability`` and
        ``heatmap`` (base64-encoded JPEG). On failure ``result`` is ``None``
        and ``error`` is a message string.
    """
    if model is None:
        # Name the checkpoint that load_model() actually looks for.
        return None, "Error: Model not loaded. Check backend logs for 'best_model.safetensors' error."

    try:
        # Read and preprocess image
        image = cv2.imread(image_path)
        if image is None:
            return None, "Error: Could not read image"

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = transform(image=image)
        image_tensor = augmented['image'].unsqueeze(0).to(device)

        # Make prediction: sigmoid of the single logit is the fake probability
        logits = model(image_tensor)
        prob = torch.sigmoid(logits).item()

        # Generate heatmap
        # NOTE(review): assumes get_heatmap returns a 2-D float array scaled
        # to [0, 1] - confirm against DeepfakeDetector.get_heatmap.
        heatmap = model.get_heatmap(image_tensor)

        # Resize to the original image size (cv2 expects (width, height))
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        # The colour map is BGR while the image is RGB, so convert the image
        # back to BGR before blending.
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        superimposed_img = heatmap * 0.4 + image_bgr * 0.6
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

        # Encode to Base64; imencode reports failure through its first value
        encoded_ok, buffer = cv2.imencode('.jpg', superimposed_img)
        if not encoded_ok:
            return None, "Error: Could not encode heatmap"
        heatmap_b64 = base64.b64encode(buffer).decode('utf-8')

        is_fake = prob > 0.5
        label = "FAKE" if is_fake else "REAL"
        confidence = prob if is_fake else 1 - prob

        return {
            'prediction': label,
            'confidence': float(confidence),
            'fake_probability': float(prob),
            'real_probability': float(1 - prob),
            'heatmap': heatmap_b64
        }, None
    except Exception as e:
        return None, str(e)
166
+
167
+
168
@app.route('/')
def index():
    """Backend root: report that the service is up and list some endpoints."""
    payload = {
        "status": "online",
        "message": "DeepGuard Backend is Running",
        "endpoints": ["/api/predict", "/api/history", "/api/health"],
    }
    return jsonify(payload)
176
+
177
@app.route('/history_uploads/<path:filename>')
def serve_history_image(filename):
    """Serve a file stored in HISTORY_FOLDER under the requested name."""
    response = send_from_directory(HISTORY_FOLDER, filename)
    return response
181
+
182
@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check: report whether the model global is set and the device."""
    model_ready = model is not None
    status = {
        'status': 'healthy',
        'model_loaded': model_ready,
        'device': str(device),
    }
    return jsonify(status)
190
+
191
@app.route('/api/predict', methods=['POST'])
def predict():
    """Handle an image upload, run prediction and record it in the history.

    Expects a multipart form with a ``file`` field. Responds with the
    prediction dict from ``predict_image`` on success, a 400 JSON error for a
    missing/invalid upload, and a 500 JSON error when prediction or storage
    fails. The temporary upload is removed on every path.
    """
    import shutil

    filepath = None
    try:
        # Check if file is present
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']

        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        # NOTE(review): allowed_file also accepts the video extensions in
        # ALLOWED_EXTENSIONS; such files fail later in predict_image.
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: png, jpg, jpeg, webp'}), 400

        # secure_filename can strip a hostile name down to nothing
        filename = secure_filename(file.filename)
        if not filename:
            return jsonify({'error': 'Invalid file name'}), 400

        # Save file
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Make prediction; predict_image reports failures via `error`
        result, error = predict_image(filepath)
        if error is not None or result is None:
            return jsonify({'error': error or 'Prediction failed'}), 500

        # Save to history: keep a copy of the original upload
        history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
        history_path = os.path.join(HISTORY_FOLDER, history_filename)
        shutil.copy(filepath, history_path)

        # Relative path for frontend
        relative_path = f"history_uploads/{history_filename}"

        database.add_scan(
            filename=filename,
            prediction=result['prediction'],
            confidence=result['confidence'],
            fake_prob=result['fake_probability'],
            real_prob=result['real_probability'],
            image_path=relative_path
        )

        return jsonify(result)

    except Exception as e:
        return jsonify({'error': str(e)}), 500

    finally:
        # Clean up the uploaded file (best effort) on success and failure alike
        if filepath is not None:
            try:
                os.remove(filepath)
            except OSError:
                pass
247
+
248
@app.route('/api/predict_video', methods=['POST'])
def predict_video():
    """Handle a video upload, run prediction and record it in the history.

    Expects a multipart form with a ``file`` field. Responds with the dict
    returned by ``video_inference.process_video`` plus a ``video_url`` key on
    success, a 400 JSON error for a missing/invalid upload, and a 500 JSON
    error when the model is not loaded or processing fails. The temporary
    upload is removed on every path.
    """
    import shutil

    filepath = None
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']

        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type'}), 400

        # Check the model before writing anything to disk
        if model is None:
            return jsonify({'error': 'Model not loaded'}), 500

        # secure_filename can strip a hostile name down to nothing
        filename = secure_filename(file.filename)
        if not filename:
            return jsonify({'error': 'Invalid file name'}), 400

        # Save file
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Process video with the already loaded model object
        result = video_inference.process_video(filepath, model, transform, device)

        if "error" in result:
            return jsonify(result), 500

        # Save the video itself to the history folder
        history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
        history_path = os.path.join(HISTORY_FOLDER, history_filename)
        shutil.copy(filepath, history_path)

        relative_path = f"history_uploads/{history_filename}"

        # Add to database, re-using the image columns: the average fake
        # probability of the video is stored as fake_prob.
        database.add_scan(
            filename=filename,
            prediction=result['prediction'],
            confidence=result['confidence'],
            fake_prob=result['avg_fake_prob'],
            real_prob=1 - result['avg_fake_prob'],
            image_path=relative_path
        )

        # Add video URL for frontend playback
        result['video_url'] = relative_path

        return jsonify(result)

    except Exception as e:
        print(f"Video Error: {e}")
        return jsonify({'error': str(e)}), 500

    finally:
        # Clean up the uploaded file (best effort) on success and failure alike
        if filepath is not None:
            try:
                os.remove(filepath)
            except OSError:
                pass
315
+
316
+
317
@app.route('/api/history', methods=['GET'])
def get_history():
    """Return all past scans as a JSON list."""
    # Query once; the result of database.get_history() is returned as-is.
    history = database.get_history()
    return jsonify(history)
323
+
324
@app.route('/api/history/<int:scan_id>', methods=['DELETE'])
def delete_scan(scan_id):
    """Delete one scan by id; respond 500 when the database call reports failure."""
    deleted = database.delete_scan(scan_id)
    if not deleted:
        return jsonify({'error': 'Failed to delete scan'}), 500
    return jsonify({'message': 'Scan deleted'})
330
+
331
@app.route('/api/history', methods=['DELETE'])
def clear_history():
    """Clear every history row; respond 500 when the database call reports failure."""
    cleared = database.clear_history()
    if not cleared:
        return jsonify({'error': 'Failed to clear history'}), 500
    return jsonify({'message': 'History cleared'})
337
+
338
@app.route('/api/model-info', methods=['GET'])
def model_info():
    """Return static model metadata together with the Config component flags."""
    components = {
        'RGB Analysis': Config.USE_RGB,
        'Frequency Domain': Config.USE_FREQ,
        'Patch-based Detection': Config.USE_PATCH,
        'Vision Transformer': Config.USE_VIT,
    }
    info = {
        'model_name': 'DeepGuard: Advanced Deepfake Detector',
        'architecture': 'Hybrid CNN-ViT',
        'components': components,
        'image_size': Config.IMAGE_SIZE,
        'device': str(device),
        'threshold': 0.5,
    }
    return jsonify(info)
354
+
355
if __name__ == '__main__':
    # Script entry point: print a banner, load the model, start the server.
    rule = "=" * 60

    print(rule)
    print("🚀 DeepGuard - Deepfake Detection System")
    print(rule)

    # Load model
    load_model()

    print(rule)
    # The port comes from the PORT environment variable, defaulting to 7860
    port = int(os.environ.get("PORT", 7860))
    print(f"🌐 Starting server on http://0.0.0.0:{port}")
    print(rule)

    app.run(debug=False, host='0.0.0.0', port=port)
backend/database.db ADDED
Binary file (20.5 kB). View file
 
backend/database.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import datetime
3
+ import os
4
+
5
+ DB_NAME = os.path.join(os.path.dirname(__file__), 'database.db')
6
+
7
def get_db_connection():
    """Open a connection to the SQLite file at ``DB_NAME``.

    Rows are produced as ``sqlite3.Row`` objects. If sqlite3 raises, the
    error is printed and ``None`` is returned.
    """
    try:
        connection = sqlite3.connect(DB_NAME)
        connection.row_factory = sqlite3.Row
    except sqlite3.Error as exc:
        print(f"Database error: {exc}")
        return None
    return connection
15
+
16
def init_db():
    """Create the ``history`` table if needed and ensure ``image_path`` exists.

    Does nothing when no connection can be opened. Any sqlite3 error is
    printed; the connection is always closed.
    """
    conn = get_db_connection()
    if conn is None:
        return

    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filename TEXT NOT NULL,
                prediction TEXT NOT NULL,
                confidence REAL NOT NULL,
                fake_probability REAL NOT NULL,
                real_probability REAL NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
        print("✅ Database initialized successfully.")

        # Migration: add image_path only when it is really missing, instead
        # of attempting the ALTER and swallowing whatever error comes back.
        existing_columns = {
            row['name'] for row in conn.execute('PRAGMA table_info(history)')
        }
        if 'image_path' not in existing_columns:
            conn.execute('ALTER TABLE history ADD COLUMN image_path TEXT')
            conn.commit()
            print("✅ Added image_path column.")
    except sqlite3.Error as e:
        print(f"Error initializing database: {e}")
    finally:
        conn.close()
45
+
46
def add_scan(filename, prediction, confidence, fake_prob, real_prob, image_path=""):
    """Insert one scan row into ``history``.

    Returns:
        True when the row was committed; False when no connection could be
        opened or sqlite3 raised (the error is printed).
    """
    conn = get_db_connection()
    if not conn:
        return False

    row_values = (filename, prediction, confidence, fake_prob, real_prob, image_path)
    try:
        conn.execute('''
            INSERT INTO history (filename, prediction, confidence, fake_probability, real_probability, image_path)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', row_values)
        conn.commit()
    except sqlite3.Error as e:
        print(f"Error adding scan: {e}")
        return False
    else:
        return True
    finally:
        conn.close()
62
+
63
def get_history():
    """Return every ``history`` row as a dict, ordered by timestamp descending.

    Returns an empty list when no connection could be opened or sqlite3
    raised (the error is printed).
    """
    conn = get_db_connection()
    if not conn:
        return []

    try:
        rows = conn.execute('SELECT * FROM history ORDER BY timestamp DESC').fetchall()
        records = []
        for row in rows:
            records.append(dict(row))
        return records
    except sqlite3.Error as e:
        print(f"Error retrieving history: {e}")
        return []
    finally:
        conn.close()
76
+
77
def clear_history():
    """Delete every row from ``history``.

    Returns:
        True when the delete was committed; False when no connection could
        be opened or sqlite3 raised (the error is printed).
    """
    conn = get_db_connection()
    if not conn:
        return False

    try:
        conn.execute('DELETE FROM history')
        conn.commit()
    except sqlite3.Error as e:
        print(f"Error clearing history: {e}")
        return False
    else:
        return True
    finally:
        conn.close()
90
+
91
def delete_scan(scan_id):
    """Delete the ``history`` row whose id equals *scan_id*.

    Returns:
        True when the statement was committed (whether or not a row
        matched); False when no connection could be opened or sqlite3
        raised (the error is printed).
    """
    conn = get_db_connection()
    if not conn:
        return False

    params = (scan_id,)
    try:
        conn.execute('DELETE FROM history WHERE id = ?', params)
        conn.commit()
    except sqlite3.Error as e:
        print(f"Error deleting scan: {e}")
        return False
    else:
        return True
    finally:
        conn.close()
104
+
105
+ # Initialize DB on module load
106
+ init_db()
backend/requirements_web.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flask==3.0.0
2
+ flask-cors==4.0.0
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ albumentations
7
+ Pillow
8
+ numpy
9
+ safetensors
model/.DS_Store ADDED
Binary file (10.2 kB). View file
 
model/results/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/results/checkpoints/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/results/checkpoints/best_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a58fd840e8ebab964a3021acb9e365fe445a86dbcb8af93d808f68eb3a254ad4
3
+ size 202457588
model/src/__init__.py ADDED
File without changes
model/src/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (175 Bytes). View file
 
model/src/__pycache__/config.cpython-314.pyc ADDED
Binary file (2.58 kB). View file
 
model/src/__pycache__/dataset.cpython-314.pyc ADDED
Binary file (5.79 kB). View file
 
model/src/__pycache__/models.cpython-314.pyc ADDED
Binary file (10.3 kB). View file
 
model/src/__pycache__/train.cpython-314.pyc ADDED
Binary file (13.8 kB). View file
 
model/src/__pycache__/utils.cpython-314.pyc ADDED
Binary file (1.46 kB). View file
 
model/src/__pycache__/video_inference.cpython-314.pyc ADDED
Binary file (6.7 kB). View file
 
model/src/config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import platform
4
+
5
class Config:
    """Static project configuration: paths, model flags and hyperparameters.

    All values are class attributes; ``setup()`` creates the directories the
    paths point at.
    """

    # System: PROJECT_ROOT is the directory that contains the src/ package
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    DATA_DIR = os.path.join(PROJECT_ROOT, "data")
    RESULTS_DIR = os.path.join(PROJECT_ROOT, "results")

    # Model Architecture
    IMAGE_SIZE = 256
    NUM_CLASSES = 1  # Logic: 0=Real, 1=Fake (Sigmoid output)

    # Component Flags
    USE_RGB = True
    USE_FREQ = True
    USE_PATCH = True
    USE_VIT = True

    # Training Hyperparameters
    BATCH_SIZE = 32  # Optimized for Mac M4 (Unified Memory)
    EPOCHS = 3
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    NUM_WORKERS = 8  # Leverage M4 Performance Cores

    # Hardware: prefer CUDA, then Apple MPS, then CPU. The getattr guard keeps
    # this working on torch builds that do not expose torch.backends.mps.
    DEVICE = (
        "cuda" if torch.cuda.is_available()
        else "mps" if (
            getattr(torch.backends, "mps", None) is not None
            and torch.backends.mps.is_available()
        )
        else "cpu"
    )

    # Paths
    # Since the root data folder is used, the training script recursively
    # finds ALL images in all sub-datasets and splits them 80/20 for
    # training/validation.
    TRAIN_DATA_PATH = DATA_DIR
    TEST_DATA_PATH = DATA_DIR
    CHECKPOINT_DIR = os.path.join(RESULTS_DIR, "checkpoints")

    @classmethod
    def setup(cls):
        """Create the results, checkpoint and data directories if missing."""
        os.makedirs(cls.RESULTS_DIR, exist_ok=True)
        os.makedirs(cls.CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(cls.DATA_DIR, exist_ok=True)
        print(f"Project initialized at {cls.PROJECT_ROOT}")
        print(f"Using device: {cls.DEVICE}")


if __name__ == "__main__":
    Config.setup()
model/src/dataset.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+ from src.config import Config
9
+
10
class DeepfakeDataset(Dataset):
    """Image dataset yielding ``(image_tensor, label)`` pairs (0.0 real, 1.0 fake)."""

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory with subfolders containing images. (Optional if file_paths provided)
            file_paths (list): List of absolute paths to images.
            labels (list): List of labels corresponding to file_paths.
            phase (str): 'train' or 'val'.
            max_samples (int): Optional limit for quick debugging.

        Raises:
            ValueError: If neither root_dir nor (file_paths, labels) is given.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()

        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def scan_directory(root_dir):
        """Walk *root_dir* and infer a label for every image from its path.

        Returns:
            ``(image_paths, labels)`` as two parallel lists. Images whose path
            matches none of the keywords are skipped.
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")

        # Valid extensions
        exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif')
        fake_markers = ("fake", "df", "synthesis", "generated", "ai")

        for root, _dirs, files in os.walk(root_dir):
            for file in files:
                if not file.lower().endswith(exts):
                    continue

                path = os.path.join(root, file)
                # Label inference based on the full, lower-cased path.
                # NOTE(review): these are substring tests, so any path that
                # contains "ai" or "df" (for example a folder called "train")
                # is labelled fake unless it also contains "real".
                path_lower = path.lower()

                if "real" in path_lower:
                    label = 0.0
                elif any(marker in path_lower for marker in fake_markers):
                    label = 1.0
                else:
                    continue

                image_paths.append(path)
                labels.append(label)

        return image_paths, labels

    def _get_transforms(self):
        """Return the augmentation pipeline for the current phase."""
        size = Config.IMAGE_SIZE
        if self.phase == 'train':
            return A.Compose([
                A.Resize(size, size),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                # NOTE(review): quality_lower/quality_upper may not be accepted
                # by every albumentations release - confirm against the
                # installed version.
                A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])
        else:
            return A.Compose([
                A.Resize(size, size),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for *idx*, skipping unreadable images.

        If the image at *idx* cannot be loaded, the following samples are
        tried in order (wrapping around) and the first readable one is
        returned together with its own label.

        Raises:
            IndexError: If *idx* is out of range.
            RuntimeError: If no image in the dataset can be loaded.
        """
        # Index once without wrapping so out-of-range indices still raise
        # IndexError, which plain iteration over the dataset relies on.
        self.image_paths[idx]

        total = len(self)
        start = idx % total
        # Bounded scan instead of unbounded recursion: at most one attempt
        # per sample.
        for offset in range(total):
            current = (start + offset) % total
            path = self.image_paths[current]
            try:
                image = cv2.imread(path)
                if image is None:
                    raise ValueError("Image not found or corrupt")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                # Fallback to next image
                continue

            if self.transform:
                augmented = self.transform(image=image)
                image = augmented['image']

            return image, torch.tensor(self.labels[current], dtype=torch.float32)

        raise RuntimeError("No readable image found in dataset")
model/src/finetune.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ import random
8
+ import ssl
9
+ # Disable SSL verification for downloading pretrained weights
10
+ ssl._create_default_https_context = ssl._create_unverified_context
11
+
12
+ from src.config import Config
13
+ from src.models import DeepfakeDetector
14
+ from src.dataset import DeepfakeDataset
15
+
16
+ try:
17
+ from safetensors.torch import save_file, load_model
18
+ SAFETENSORS_AVAILABLE = True
19
+ except ImportError:
20
+ SAFETENSORS_AVAILABLE = False
21
+ print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
22
+
23
def finetune():
    """Fine-tune the detector on Dataset C starting from ``best_model.safetensors``.

    Scans the dataset directory, splits it 80/20 into train/validation,
    trains for a fixed number of epochs with a low learning rate, and saves a
    checkpoint after every epoch plus the best model by validation accuracy.
    The dataset location can be overridden with the ``FINETUNE_DATA_PATH``
    environment variable.

    Raises:
        RuntimeError: If a checkpoint exists but safetensors is not installed.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # Fine-tuning dataset path (the default is the original hard-coded one)
    FINETUNE_DATA_PATH = os.environ.get(
        "FINETUNE_DATA_PATH",
        "/Users/harshvardhan/Developer/dataset/Dataset c",
    )

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET C")
    print(f"{'='*80}\n")

    # --- Data Loading ---
    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    if not os.path.exists(FINETUNE_DATA_PATH):
        print(f"❌ Error: Dataset path not found: {FINETUNE_DATA_PATH}")
        return

    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)

    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

    # Shuffle and split 80/20
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)

    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]

    # With very few images the train split is empty and zip(*...) below
    # could not be unpacked.
    if not train_data:
        print(f"Not enough images in {FINETUNE_DATA_PATH} to build a training split")
        return

    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')

    # Dataloaders
    use_pinned = device.type == 'cuda'
    keep_workers = Config.NUM_WORKERS > 0
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=use_pinned,
                              persistent_workers=keep_workers)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=use_pinned,
                            persistent_workers=keep_workers)

    # Load pre-trained model from Dataset A
    print("\n🔄 Loading pre-trained model from Dataset A...")
    model = DeepfakeDetector(pretrained=False).to(device)

    # Resolve the checkpoint through Config so the script does not depend on
    # the current working directory (save_checkpoint writes to the same place).
    checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if os.path.exists(checkpoint_path):
        if not SAFETENSORS_AVAILABLE:
            raise RuntimeError(
                f"safetensors is required to load {checkpoint_path} but is not installed"
            )
        load_model(model, checkpoint_path, strict=False)
        print(f"✅ Loaded checkpoint: {checkpoint_path}")
    else:
        print("⚠️ No checkpoint found! Starting from random weights.")

    model.to(device)

    # Optimization with LOWER learning rate for fine-tuning
    FINETUNE_LR = 1e-5  # 10x lower than original training
    FINETUNE_EPOCHS = 2

    print("\n📝 Fine-tuning settings:")
    print(f" Learning Rate: {FINETUNE_LR} (10x lower for fine-tuning)")
    print(f" Epochs: {FINETUNE_EPOCHS}")
    print(f" Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Loop
    best_acc = 0.0

    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            # Labels arrive as shape (batch,); add a trailing dimension
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetC_ep{epoch+1}")

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetC")

        scheduler.step()

    print("\n🎉 Fine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print("\n💾 Checkpoints saved in: results/checkpoints/")
143
+
144
def validate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without tracking gradients.

    Args:
        model: Module called on each image batch; put into eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss_per_batch, accuracy)``; ``(0.0, 0.0)`` when the loader
        yields no samples, instead of dividing by zero.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Labels arrive as shape (batch,); add a trailing dimension
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            # Threshold the sigmoid of the outputs at 0.5
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    return val_loss / num_batches, correct / total
164
+
165
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Write the model weights to ``Config.CHECKPOINT_DIR``.

    Saves ``<name>.safetensors`` when safetensors is available; otherwise, or
    if that save raises, writes the state dict to ``<name>.pth`` with
    ``torch.save``. ``epoch`` and ``acc`` are accepted but not used.
    """
    weights = model.state_dict()
    filename = f"{name}.safetensors"
    target = os.path.join(Config.CHECKPOINT_DIR, filename)
    fallback_target = target.replace(".safetensors", ".pth")

    if not SAFETENSORS_AVAILABLE:
        torch.save(weights, fallback_target)
        return

    try:
        from safetensors.torch import save_model
        save_model(model, target)
        print(f"✅ Saved: {filename}")
    except Exception as e:
        print(f"SafeTensors save failed, falling back to .pth: {e}")
        torch.save(weights, fallback_target)
180
+
181
+ if __name__ == "__main__":
182
+ finetune()
model/src/finetune_dataset_a.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ import random
8
+ import ssl
9
+ import platform
10
+
11
+ # Disable SSL verification for downloading pretrained weights
12
+ ssl._create_default_https_context = ssl._create_unverified_context
13
+
14
+ from src.config import Config
15
+ from src.models import DeepfakeDetector
16
+ from src.dataset import DeepfakeDataset
17
+
18
+ try:
19
+ from safetensors.torch import save_file, load_model
20
+ SAFETENSORS_AVAILABLE = True
21
+ except ImportError:
22
+ SAFETENSORS_AVAILABLE = False
23
+ print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
24
+
25
def finetune():
    """Fine-tune the best available DeepfakeDetector checkpoint on Dataset A.

    Steps: scan the dataset directory, shuffle and split it 80/20, load
    ``best_model`` from ``Config.CHECKPOINT_DIR`` (``.safetensors`` first,
    then ``.pth``), train for a few epochs at a low learning rate, save a
    checkpoint after every epoch and keep the best one by validation
    accuracy. Returns ``None``; progress is reported through ``print``.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # Fine-tuning dataset path - Dataset A (machine-specific locations)
    if platform.system() == "Windows":
        FINETUNE_DATA_PATH = r"C:\Users\kanna\Downloads\Dataset\Dataset A\Dataset A"
    else:
        FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset A"

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET A")
    print(f"{'='*80}\n")

    # --- Data Loading ---
    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    if not os.path.exists(FINETUNE_DATA_PATH):
        print(f"❌ Error: Dataset path not found: {FINETUNE_DATA_PATH}")
        return

    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)

    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

    # Shuffle and split
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)

    # Use 80/20 split for fine-tuning dataset
    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]

    # With a single image the training split is empty and zip(*train_data)
    # below would raise ValueError, so stop with a clear message instead.
    if not train_data:
        print(f"❌ Error: Not enough images ({len(combined)}) to build a training split")
        return

    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')

    # Dataloaders - Use Config.BATCH_SIZE but ensure it fits GPU
    loader_kwargs = dict(
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        pin_memory=device.type == 'cuda',
        persistent_workers=Config.NUM_WORKERS > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Load pre-trained model
    print("\n🔄 Loading pre-trained model (best_model)...")
    model = DeepfakeDetector(pretrained=False).to(device)

    # Try to load the best model found so far
    checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if not os.path.exists(checkpoint_path):
        # Fall back to the legacy .pth checkpoint
        checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.pth")

    if os.path.exists(checkpoint_path):
        try:
            if checkpoint_path.endswith(".safetensors"):
                load_model(model, checkpoint_path, strict=False)
            else:
                model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"✅ Loaded checkpoint: {checkpoint_path}")
        except Exception as e:
            # Best effort: a broken checkpoint must not abort the run.
            print(f"⚠️ Error loading checkpoint: {e}")
            print("Starting from random weights (not ideal for fine-tuning!)")
    else:
        print("⚠️ No checkpoint found! Starting from random weights.")

    model.to(device)

    # Optimization with LOWER learning rate for fine-tuning
    FINETUNE_LR = 1e-5  # 10x lower than original training
    FINETUNE_EPOCHS = 5  # Give it a few epochs to adapt

    print(f"\n📝 Fine-tuning settings:")
    print(f" Learning Rate: {FINETUNE_LR} (Low LR for fine-tuning)")
    print(f" Epochs: {FINETUNE_EPOCHS}")
    print(f" Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    # mode='max' because the scheduler is stepped with validation accuracy.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    # Loop
    best_acc = 0.0

    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0) if labels.size(0) > 0 else 0)

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetA_ep{epoch+1}")

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            scheduler.step(val_acc)

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetA")

    print(f"\n🎉 Fine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print(f"\n💾 Checkpoints saved in: {Config.CHECKPOINT_DIR}")
165
+
166
def validate(model, loader, criterion, device):
    """Run one evaluation pass over ``loader``.

    Args:
        model: Network producing one logit per sample.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)``; ``(0.0, 0.0)`` when the
        loader yields nothing, so an empty split cannot raise
        ZeroDivisionError.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Labels get a trailing dimension so they line up with the logits.
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            running_loss += criterion(outputs, labels).item()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)
            n_batches += 1

    if n_batches == 0 or n_seen == 0:
        return 0.0, 0.0
    return running_loss / n_batches, n_correct / n_seen
186
+
187
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Persist ``model`` under ``Config.CHECKPOINT_DIR`` as ``<name>``.

    Tries the safetensors format first; if safetensors is unavailable or
    the save raises, the state dict is written with ``torch.save`` using
    a ``.pth`` suffix instead.

    Args:
        model: Module to save.
        epoch: Unused here; kept so callers can pass it.
        acc: Unused here; kept so callers can pass it.
        name: File stem of the checkpoint.
    """
    state_dict = model.state_dict()
    filename = f"{name}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, filename)

    def _save_legacy():
        # Same location, legacy pickle-based format.
        torch.save(state_dict, path.replace(".safetensors", ".pth"))

    if not SAFETENSORS_AVAILABLE:
        _save_legacy()
        return

    try:
        from safetensors.torch import save_model
        save_model(model, path)
        print(f"✅ Saved: {filename}")
    except Exception as e:
        print(f"SafeTensors save failed, falling back to .pth: {e}")
        _save_legacy()
202
+
203
+ if __name__ == "__main__":
204
+ finetune()
model/src/inference.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import cv2
4
+ import os
5
+ import glob
6
+ import numpy as np
7
+ import ssl
8
+ # Disable SSL verification
9
+ ssl._create_default_https_context = ssl._create_unverified_context
10
+
11
+ import albumentations as A
12
+ from albumentations.pytorch import ToTensorV2
13
+ from src.models import DeepfakeDetector
14
+ from src.config import Config
15
+
16
+ try:
17
+ from safetensors.torch import load_file
18
+ SAFETENSORS_AVAILABLE = True
19
+ except ImportError:
20
+ SAFETENSORS_AVAILABLE = False
21
+
22
def get_transform():
    """Build the inference preprocessing pipeline.

    Resizes to ``Config.IMAGE_SIZE`` on both sides, normalises with the
    mean/std constants below, and converts the image to a tensor.
    """
    side = Config.IMAGE_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
28
+
29
def load_models(checkpoints_arg, device):
    """
    Load one or multiple models for ensemble inference.

    Args:
        checkpoints_arg: A directory (all ``*.safetensors`` files are
            used, or all ``*.pth`` files when none exist), a single
            path, or a comma-separated list of paths.
        device: Device the models are moved to.

    Returns:
        Non-empty list of models in eval mode. When no checkpoint could
        be loaded, one randomly initialised model is returned so the
        inference flow can still be exercised.
    """
    if os.path.isdir(checkpoints_arg):
        # sorted() keeps the ensemble order reproducible across filesystems.
        paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.safetensors")))
        if not paths:
            paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.pth")))
    else:
        paths = checkpoints_arg.split(',')

    # Drop blank entries up front so the reported count matches what is loaded.
    paths = [p.strip() for p in paths if p.strip()]

    models = []
    print(f"Loading {len(paths)} model(s) for ensemble inference...")

    for path in paths:
        print(f"Loading: {path}")
        model = DeepfakeDetector(pretrained=False)  # Structure only
        model.to(device)
        model.eval()

        try:
            if path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                state_dict = load_file(path)
            else:
                # NOTE(review): torch.load unpickles the file; only use
                # checkpoints from trusted sources.
                state_dict = torch.load(path, map_location=device)
            model.load_state_dict(state_dict)
            models.append(model)
            print(f"✅ Successfully loaded: {os.path.basename(path)}")
        except Exception as e:
            # One bad checkpoint must not stop the others from loading.
            print(f"❌ Failed to load {path}: {e}")
            import traceback
            traceback.print_exc()

    if not models:
        # Fallback for testing if no checkpoint exists yet
        print("Warning: No valid checkpoints loaded. Using random initialization for testing flow.")
        model = DeepfakeDetector(pretrained=False).to(device)
        model.eval()
        models.append(model)

    return models
75
+
76
def predict_ensemble(models, image_path, device, transform):
    """Score one image with every model and average the probabilities.

    Args:
        models: Models returning a single logit for a one-image batch.
        image_path: Path of the image to read with OpenCV.
        device: Device the image tensor is moved to.
        transform: Callable invoked as ``transform(image=...)`` whose
            result holds the tensor under the ``'image'`` key.

    Returns:
        ``(probability, None)`` on success, or ``(None, message)`` when
        the image cannot be read or converted.
    """
    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            return None, "Error: Could not read image"
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:
        return None, str(exc)

    batch = transform(image=rgb)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        scores = [torch.sigmoid(net(batch)).item() for net in models]

    # Ensemble Strategy: Average Probability
    return sum(scores) / len(scores), None
98
+
99
def main():
    """Command-line entry point: classify one image or a directory of images."""
    parser = argparse.ArgumentParser(description="Deepfake Detection Inference (Ensemble Support)")
    parser.add_argument("--source", type=str, required=True, help="Path to image or directory")
    parser.add_argument("--checkpoints", type=str, default="results/checkpoints", help="Path to checkpoint file, list of files (comma-separated), or directory")
    parser.add_argument("--device", type=str, default=Config.DEVICE, help="Device to use (cuda/mps/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load Models
    models = load_models(args.checkpoints, device)
    transform = get_transform()

    # Process Source: a directory is filtered down to known image extensions.
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    if os.path.isdir(args.source):
        candidates = glob.glob(os.path.join(args.source, "*.*"))
        files = [f for f in candidates if f.lower().endswith(image_exts)]
    else:
        files = [args.source]

    rule = "-" * 65
    print(f"Processing {len(files)} images with {len(models)} model(s)...")
    print(rule)
    print(f"{'Image Name':<40} | {'Prediction':<10} | {'Confidence':<10}")
    print(rule)

    for file_path in files:
        prob, error = predict_ensemble(models, file_path, device, transform)
        if error:
            print(f"{os.path.basename(file_path):<40} | ERROR: {error}")
            continue

        # Confidence is reported for whichever class was predicted.
        if prob > 0.5:
            label, confidence = "FAKE", prob
        else:
            label, confidence = "REAL", 1 - prob

        print(f"{os.path.basename(file_path):<40} | {label:<10} | {confidence:.2%}")
137
+
138
+ if __name__ == "__main__":
139
+ main()
model/src/models.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ import numpy as np
6
+ from src.utils import get_fft_feature
7
+
8
class RGBBranch(nn.Module):
    """Spatial feature extractor built on torchvision's EfficientNet V2 Small."""

    def __init__(self, pretrained=True):
        """Create the backbone, with default pretrained weights if requested."""
        super().__init__()
        if pretrained:
            weights = models.EfficientNet_V2_S_Weights.DEFAULT
        else:
            weights = None
        self.net = models.efficientnet_v2_s(weights=weights)
        # Expose the stages that run before the classification head.
        self.features = self.net.features
        self.avgpool = self.net.avgpool
        self.out_dim = 1280

    def forward(self, x):
        """Return the pooled backbone features flattened to (batch, features)."""
        pooled = self.avgpool(self.features(x))
        return torch.flatten(pooled, 1)
24
+
25
class FreqBranch(nn.Module):
    """Small CNN that encodes a 3-channel frequency-domain image to 128 features."""

    def __init__(self):
        """Stack three conv/batch-norm/ReLU stages; the last one pools to 1x1."""
        super().__init__()
        widths = (3, 32, 64, 128)
        final_stage = len(widths) - 2

        # Layers are created in the same order as a hand-written Sequential,
        # so the parameter names (net.0, net.1, ...) stay the same.
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            if stage == final_stage:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.MaxPool2d(2))

        self.net = nn.Sequential(*layers)
        self.out_dim = 128

    def forward(self, x):
        """Return the encoded features flattened to (batch, 128)."""
        encoded = self.net(x)
        return torch.flatten(encoded, 1)
49
+
50
class PatchBranch(nn.Module):
    """Encode non-overlapping square patches and keep the strongest response.

    Every patch goes through the same lightweight CNN; the per-patch
    feature vectors are max-pooled over patches, so one strongly
    responding patch dominates the output.
    """

    def __init__(self, patch_size=64):
        """Build the shared patch encoder.

        Args:
            patch_size: Side length of each square patch. The default of
                64 yields a 4x4 grid for 256x256 inputs.
        """
        super().__init__()
        self.patch_size = patch_size
        # Shared lightweight CNN applied to each patch
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 64

    def forward(self, x):
        """Return a (batch, 64) tensor of patch features max-pooled over patches.

        ``x`` is a 4-D image batch whose height and width are at least
        ``patch_size``; rows/columns that do not fill a whole patch are
        dropped by ``unfold``.
        """
        size = self.patch_size
        # kernel == stride, so the patches tile the image without overlap:
        # (B, C, H_grid, W_grid, size, size)
        patches = x.unfold(2, size, size).unfold(3, size, size)
        B, C, H_grid, W_grid, H_patch, W_patch = patches.shape

        # Merge batch and grid dimensions so all patches are encoded in one pass.
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
            B * H_grid * W_grid, C, H_patch, W_patch
        )

        feats = torch.flatten(self.patch_encoder(patches), 1)
        feats = feats.view(B, H_grid * W_grid, -1)

        # Max over patches captures the "most fake" patch signal.
        return feats.max(dim=1).values
91
+
92
class ViTBranch(nn.Module):
    """Global-context feature extractor built on torchvision's Swin V2 Tiny."""

    def __init__(self, pretrained=True):
        """Create the backbone and replace its head with an identity layer."""
        super().__init__()
        weights = None
        if pretrained:
            weights = models.Swin_V2_T_Weights.DEFAULT
        backbone = models.swin_v2_t(weights=weights)

        # The head's input width is the feature width once the head is removed.
        self.out_dim = backbone.head.in_features
        backbone.head = nn.Identity()
        self.net = backbone

    def forward(self, x):
        """Return the backbone output for ``x`` (head is an identity)."""
        features = self.net(x)
        return features
105
+
106
class DeepfakeDetector(nn.Module):
    """Four-branch detector: RGB, frequency, patch and transformer features.

    The branch outputs are concatenated and passed to a small MLP that
    produces one logit per image.
    """

    def __init__(self, pretrained=True):
        """Build the branches and the fusion classifier.

        Args:
            pretrained: Forwarded to the RGB and ViT branches.
        """
        super().__init__()
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)

        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)

        # Fusion head over the concatenated branch features
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        """Return one logit per image, shape (batch, 1)."""
        # 1. Spatial Analysis
        rgb_feat = self.rgb_branch(x)

        # 2. Frequency Analysis
        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)

        # 3. Patch Analysis (Local Inconsistencies)
        patch_feat = self.patch_branch(x)

        # 4. Global Consistency (ViT)
        vit_feat = self.vit_branch(x)

        # 5. Feature Fusion
        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)

        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate a Grad-CAM heatmap from the last RGB-branch feature block.

        Args:
            x: Input batch. ``backward()`` is called on the raw logits, so
                the output must hold a single element, i.e. one image.

        Returns:
            2-D numpy array, clipped at zero and scaled to a maximum of 1
            when the maximum is non-zero.

        Note:
            Gradients must be enabled, and the backward pass accumulates
            gradients into the model parameters after ``zero_grad()``.
        """
        gradients = []
        activations = []

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        # Register hooks on the last block of the RGB branch features.
        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        try:
            logits = self(x)
            self.zero_grad()
            logits.backward()
        finally:
            # Always detach the hooks, even if the passes above raise,
            # so they do not pile up on the layer across calls.
            hook_b.remove()
            hook_f.remove()

        # Channel importance: gradient averaged over batch and spatial dims.
        pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
        # Detach so the weighting below does not touch the autograd graph.
        activation = activations[0][0].detach()

        # Weight activations by gradients (Grad-CAM), all channels at once.
        weighted = activation * pooled_gradients.view(-1, 1, 1)

        heatmap = torch.mean(weighted, dim=0).cpu().numpy()
        heatmap = np.maximum(heatmap, 0)  # ReLU

        # Normalize
        peak = np.max(heatmap)
        if peak != 0:
            heatmap /= peak

        return heatmap
model/src/test_dataloading.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from src.config import Config
4
+ from src.dataset import DeepfakeDataset
5
+
6
def test_dataloading():
    """Smoke-test directory scanning, the 80/20 split and dataset indexing.

    Prints a pass/fail line per stage and returns early on the first failure.
    """
    print("Testing Data Loading & Splitting Logic...")
    Config.setup()

    print(f"Data Path: {Config.TRAIN_DATA_PATH}")

    # 1. Test Scan
    paths, labels = DeepfakeDataset.scan_directory(Config.TRAIN_DATA_PATH)
    total_files = len(paths)
    print(f"Total images found: {total_files}")

    if total_files == 0:
        print("[FAIL] No images found! Check path.")
        return

    # 2. Simulate Split Logic
    samples = list(zip(paths, labels))
    random.shuffle(samples)
    cut = int(len(samples) * 0.8)
    train_data, val_data = samples[:cut], samples[cut:]

    print(f"Train Split: {len(train_data)} images")
    print(f"Val Split: {len(val_data)} images")

    # 3. Test Dataset Initialization
    try:
        train_paths, train_labels = zip(*train_data)
        ds = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
        print(f"[Pass] Dataset initialized with {len(ds)} samples.")

        # Test Get Item
        img, lbl = ds[0]
        print(f"[Pass] Loaded sample image. Shape: {img.shape}, Label: {lbl}")
    except Exception as e:
        print(f"[FAIL] Dataset initialization or loading error: {e}")
        return

    print("\nSUCCESS: Data loading verification passed!")
45
+
46
+ if __name__ == "__main__":
47
+ test_dataloading()
model/src/test_dryrun.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.models import DeepfakeDetector
4
+ from src.config import Config
5
+
6
def test_model_architecture():
    """Dry-run the detector on random input: init, forward, loss and backward.

    Each stage prints a pass/fail line; an exception in any stage ends the run.
    """
    print("Testing DeepfakeDetector Architecture...")

    # Check device
    device = torch.device("cpu")  # Test on CPU for simplicity or Config.DEVICE
    print(f"Device: {device}")

    # Initialize Model
    try:
        model = DeepfakeDetector(pretrained=False).to(device)
        print("[Pass] Model Initialization")
    except Exception as e:
        print(f"[Fail] Model Initialization: {e}")
        return

    # Create dummy input
    batch_size = 2
    side = Config.IMAGE_SIZE
    x = torch.randn(batch_size, 3, side, side).to(device)
    print(f"Input Shape: {x.shape}")

    # Forward Pass
    try:
        out = model(x)
        print(f"Output Shape: {out.shape}")

        if out.shape != (batch_size, 1):
            print(f"[Fail] Output Shape Incorrect. Expected ({batch_size}, 1), got {out.shape}")
        else:
            print("[Pass] Output Shape Correct")
    except Exception as e:
        print(f"[Fail] Forward Pass: {e}")
        # Debug trace
        import traceback
        traceback.print_exc()
        return

    # Loss and Backward
    try:
        criterion = nn.BCEWithLogitsLoss()
        target = torch.ones(batch_size, 1).to(device)
        loss = criterion(out, target)
        loss.backward()
        print(f"[Pass] Backward Pass (Loss: {loss.item():.4f})")
    except Exception as e:
        print(f"[Fail] Backward Pass: {e}")
        return

    print("\nSUCCESS: Model architecture verification passed!")
54
+
55
+ if __name__ == "__main__":
56
+ test_model_architecture()
model/src/train.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ import random
8
+ import ssl
9
+ # Disable SSL verification for downloading pretrained weights
10
+ ssl._create_default_https_context = ssl._create_unverified_context
11
+ from torch.cuda.amp import GradScaler, autocast
12
+
13
+ from src.config import Config
14
+ from src.models import DeepfakeDetector
15
+ from src.dataset import DeepfakeDataset
16
+
17
+ try:
18
+ from safetensors.torch import save_file, load_file
19
+ SAFETENSORS_AVAILABLE = True
20
+ except ImportError:
21
+ SAFETENSORS_AVAILABLE = False
22
+ print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
23
+
24
def train():
    """Train the multi-branch DeepfakeDetector.

    Builds the train/validation datasets (an automatic 80/20 split when
    both configured paths are identical), optionally resumes weights from
    an existing checkpoint, then trains for ``Config.EPOCHS`` epochs. A
    checkpoint is saved after every epoch and the best model by
    validation accuracy is saved separately. Returns ``None``.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # --- Data Loading with Automatic Split ---
    if Config.TRAIN_DATA_PATH == Config.TEST_DATA_PATH:
        print("Train and Test paths are identical. Performing automatic 80/20 shuffle split...")
        all_paths, all_labels = DeepfakeDataset.scan_directory(Config.TRAIN_DATA_PATH)

        if len(all_paths) == 0:
            print(f"No images found in {Config.TRAIN_DATA_PATH}")
            return

        # Combine and shuffle
        combined = list(zip(all_paths, all_labels))
        random.shuffle(combined)

        split_idx = int(len(combined) * 0.8)
        train_data = combined[:split_idx]
        val_data = combined[split_idx:]

        # A single image leaves the training split empty, and
        # zip(*train_data) below would raise ValueError.
        if not train_data:
            print(f"Not enough images in {Config.TRAIN_DATA_PATH} to build a training split")
            return

        train_paths, train_labels = zip(*train_data)
        val_paths, val_labels = zip(*val_data)

        train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
        val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
    else:
        # Standard folder-based loading
        train_dataset = DeepfakeDataset(root_dir=Config.TRAIN_DATA_PATH, phase='train')
        val_dataset = DeepfakeDataset(root_dir=Config.TEST_DATA_PATH, phase='val')

    # Dataloaders
    loader_kwargs = dict(
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        pin_memory=device.type == 'cuda',
        persistent_workers=Config.NUM_WORKERS > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Model
    print("Initializing Multi-Branch DeepfakeDetector...")
    model = DeepfakeDetector(pretrained=True).to(device)

    # Optimization
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Enable AMP only for CUDA (Windows NVIDIA)
    use_amp = (Config.DEVICE == 'cuda')
    scaler = GradScaler() if use_amp else None
    if use_amp:
        print("🚀 Mixed Precision (AMP) Enabled for RTX GPU")
    else:
        print("🐌 Standard Precision (No AMP) for CPU/MPS")

    # Resume from checkpoint if exists
    start_epoch = 0
    best_acc = 0.0

    # Priority:
    # 1. best_model.safetensors (if we crashed mid-training)
    # 2. latest checkpoint_ep*.safetensors (also sets the starting epoch)
    # 3. patched_model.safetensors (the model we want to improve)
    resume_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if not os.path.exists(resume_path):
        # Look for latest epoch checkpoint
        import glob
        import re
        checkpoints = glob.glob(os.path.join(Config.CHECKPOINT_DIR, "checkpoint_ep*.safetensors"))
        if checkpoints:
            # Sort by epoch number
            def get_epoch(p):
                match = re.search(r"checkpoint_ep(\d+)", p)
                return int(match.group(1)) if match else 0

            latest_ckpt = max(checkpoints, key=get_epoch)
            resume_path = latest_ckpt
            start_epoch = get_epoch(latest_ckpt)
            print(f"🔄 Auto-Resuming from latest epoch: {start_epoch}")
        else:
            resume_path = os.path.join(Config.CHECKPOINT_DIR, "patched_model.safetensors")

    if os.path.exists(resume_path):
        print(f"\n🔄 Found existing checkpoint: {resume_path}")
        print("Auto-resuming to FINETUNE this model...")

        try:
            if resume_path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                state_dict = load_file(resume_path)
            else:
                state_dict = torch.load(resume_path, map_location=device)

            # Use strict=False to allow for minor architecture changes or missing keys
            model.load_state_dict(state_dict, strict=False)
            print("✅ Weights loaded. Starting Fine-Tuning.")
        except Exception as e:
            # Best effort: an unreadable checkpoint must not abort training.
            print(f"⚠ Failed to load checkpoint: {e}")
            print("Starting from ImageNet weights.")

    # Loop
    for epoch in range(start_epoch, Config.EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()

            if use_amp:
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard training for Mac/CPU
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, best=False)

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, best=True)

        scheduler.step()

    print(f"\n🎉 Training Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
189
+
190
def validate(model, loader, criterion, device):
    """Compute mean batch loss and accuracy of ``model`` over ``loader``.

    Args:
        model: Network producing one logit per sample.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)``; ``(0.0, 0.0)`` for a
        loader that yields no batches rather than a ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    hits = 0
    seen = 0
    steps = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Labels get a trailing dimension so they line up with the logits.
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            hits += (preds == labels).sum().item()
            seen += labels.size(0)
            steps += 1

    if steps == 0 or seen == 0:
        return 0.0, 0.0
    return total_loss / steps, hits / seen
210
+
211
def _log_checkpoint(name, epoch, acc):
    """Append the saved checkpoint to the markdown history logs (best effort)."""
    try:
        from datetime import datetime
        log_path = os.path.join(Config.PROJECT_ROOT, "TRAINING_HISTORY.md")
        timestamp = datetime.now().strftime("%Y-%m-%d | %I:%M %p")
        date_part, time_part = timestamp.split(' | ')

        # Create file with header if doesn't exist
        if not os.path.exists(log_path):
            with open(log_path, "w", encoding="utf-8") as f:
                f.write("# 📜 Training History Log\n\n")
                f.write("| Date | Time | Model Name | Dataset | Epochs | Accuracy | Loss | Status |\n")
                f.write("| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n")

        # Append Entry to Summary Log
        with open(log_path, "a", encoding="utf-8") as f:
            # Format: Date | Time | Name | Dataset | Epoch | Acc | Loss | Status
            dataset_name = os.path.basename(Config.DATA_DIR)
            entry = f"| **{date_part}** | {time_part} | {name} | {dataset_name} | {epoch} | {acc*100:.2f}% | N/A | ✅ Saved |\n"
            f.write(entry)
        print(f"📝 Logged to TRAINING_HISTORY.md")

        # 📝 Detailed Lab Notebook Logging
        detail_path = os.path.join(Config.PROJECT_ROOT, "DETAILED_HISTORY.md")
        with open(detail_path, "a", encoding="utf-8") as f:
            f.write(f"\n## Model: {name} (Epoch {epoch})\n")
            f.write(f"| Feature | Detail |\n| :--- | :--- |\n")
            f.write(f"| **Date** | {timestamp} |\n")
            f.write(f"| **Training Accuracy** | {acc*100:.2f}% |\n")
            f.write(f"| **Dataset** | {Config.DATA_DIR} |\n")
            f.write(f"| **Batch Size** | {Config.BATCH_SIZE} |\n")
            f.write(f"| **Optimizer** | AdamW (lr={Config.LEARNING_RATE}) |\n")
            f.write(f"| **Device** | {Config.DEVICE.upper()} |\n")
            f.write("\n---\n")
        print(f"📘 Detailed log written to DETAILED_HISTORY.md")

    except Exception as e:
        # Logging is a convenience; never let it break checkpointing.
        print(f"⚠️ Failed to write log: {e}")


def save_checkpoint(model, epoch, acc, best=False):
    """Save the model weights and record the save in the history logs.

    The file is ``best_model.safetensors`` when ``best`` is true, else
    ``checkpoint_ep<epoch>.safetensors``, under ``Config.CHECKPOINT_DIR``.
    When safetensors is unavailable or its save raises, the state dict
    is written with ``torch.save`` under a ``.pth`` suffix instead, and
    no history entry is written.

    Args:
        model: Module whose weights are saved.
        epoch: Epoch number used in the file name and log entries.
        acc: Accuracy as a fraction; logged as a percentage.
        best: Whether this is the best-so-far model.
    """
    state_dict = model.state_dict()
    name = "best_model.safetensors" if best else f"checkpoint_ep{epoch}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, name)
    legacy_path = path.replace(".safetensors", ".pth")

    if not SAFETENSORS_AVAILABLE:
        torch.save(state_dict, legacy_path)
        # Report the file that was actually written (.pth, not .safetensors).
        print(f"Saved Checkpoint (Legacy): {legacy_path}")
        return

    try:
        # save_model handles shared tensors
        from safetensors.torch import save_model
        save_model(model, path)
        print(f"Saved Checkpoint: {path}")
    except Exception as e:
        # Fallback to regular torch save if safetensors fails
        print(f"SafeTensors save failed ({e}), falling back to .pth format")
        torch.save(state_dict, legacy_path)
        print(f"Saved Checkpoint (Legacy): {legacy_path}")
        return

    # 📝 Auto-Log to History
    _log_checkpoint(name, epoch, acc)
269
+
270
+ if __name__ == "__main__":
271
+ train()
model/src/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+
5
def get_fft_feature(x):
    """
    Build a centered log-magnitude frequency map for a batch of images.

    A 3-D input is treated as a single image and gains a leading batch axis
    before the transform is applied over the last two dimensions.

    Args:
        x (torch.Tensor): Images of shape (B, C, H, W), or one image (C, H, W).
    Returns:
        torch.Tensor: Log-magnitude spectrum of shape (B, C, H, W), with the
        zero-frequency bin moved to the middle of the last two axes.
    """
    batch = x.unsqueeze(0) if x.dim() == 3 else x

    # Orthonormal 2D FFT over the spatial axes.
    spectrum = torch.fft.fft2(batch, norm='ortho')

    # Magnitude on a log scale; the epsilon keeps log() finite at zero bins.
    log_magnitude = (spectrum.abs() + 1e-6).log()

    # Re-center so low frequencies sit in the middle of the map.
    return torch.fft.fftshift(log_magnitude, dim=(-2, -1))
29
+
30
def min_max_normalize(tensor):
    """
    Rescale a tensor by its global minimum and maximum.

    The result is (tensor - min) / (max - min + 1e-8); the small epsilon in
    the denominator avoids division by zero when all values are equal.
    """
    lowest, highest = tensor.min(), tensor.max()
    value_range = highest - lowest + 1e-8
    return (tensor - lowest) / value_range
37
+
model/src/video_inference.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+
7
+ def process_video(video_path, model, transform, device, frames_per_second=1):
8
+ """
9
+ Process a video file frame-by-frame using the deepfake detection model.
10
+
11
+ Args:
12
+ video_path (str): Path to the video file.
13
+ model (torch.nn.Module): Loaded PyTorch model.
14
+ transform (callable): Albumentations transform pipeline.
15
+ device (torch.device): Device to run inference on.
16
+ frames_per_second (int): Number of frames to sample per second of video.
17
+ Default is 1 to keep processing fast.
18
+
19
+ Returns:
20
+ dict: Aggregated results including verdict, average confidence, and frame-level details.
21
+ """
22
+ if model is None:
23
+ return {"error": "Model not loaded"}
24
+
25
+ cap = cv2.VideoCapture(video_path)
26
+ if not cap.isOpened():
27
+ return {"error": "Could not open video file"}
28
+
29
+ # specific video properties
30
+ fps = cap.get(cv2.CAP_PROP_FPS)
31
+ if fps <= 0: fps = 30 # Fallback
32
+
33
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
34
+ duration = total_frames / fps
35
+
36
+ # Calculate sampling interval (step size)
37
+ # If we want 1 frame per second, we step by 'fps' frames
38
+ step = int(fps / frames_per_second)
39
+ if step < 1: step = 1
40
+
41
+ frame_indices = []
42
+ probs = []
43
+
44
+ print(f"Processing video: {video_path}")
45
+ print(f"Duration: {duration:.2f}s, FPS: {fps}, Total Frames: {total_frames}")
46
+ print(f"Sampling every {step} frames...")
47
+
48
+ count = 0
49
+ processed_count = 0
50
+
51
+ suspicious_frames = [] # Store frames with high fake probability
52
+
53
+ while cap.isOpened():
54
+ ret, frame = cap.read()
55
+ if not ret:
56
+ break
57
+
58
+ if count % step == 0:
59
+ # Process this frame
60
+ try:
61
+ # Convert BGR (OpenCV) to RGB
62
+ image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63
+
64
+ # --- Face Extraction ---
65
+ # Load Haar Cascade (lazy load)
66
+ if not hasattr(process_video, "face_cascade"):
67
+ try:
68
+ cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
69
+ process_video.face_cascade = cv2.CascadeClassifier(cascade_path)
70
+ except:
71
+ process_video.face_cascade = None
72
+
73
+ face_crop = None
74
+ if process_video.face_cascade:
75
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
76
+ faces = process_video.face_cascade.detectMultiScale(
77
+ gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60)
78
+ )
79
+
80
+ if len(faces) > 0:
81
+ # Find largest face
82
+ largest_face = max(faces, key=lambda rect: rect[2] * rect[3])
83
+ x, y, w, h = largest_face
84
+
85
+ # Add margin (20%)
86
+ margin = int(max(w, h) * 0.2)
87
+ x_start = max(x - margin, 0)
88
+ y_start = max(y - margin, 0)
89
+ x_end = min(x + w + margin, frame.shape[1])
90
+ y_end = min(y + h + margin, frame.shape[0])
91
+
92
+ face_crop = image[y_start:y_end, x_start:x_end]
93
+
94
+ # Use face crop if found, otherwise use full image
95
+ input_image = face_crop if face_crop is not None else image
96
+
97
+ # Apply transforms
98
+ augmented = transform(image=input_image)
99
+ image_tensor = augmented['image'].unsqueeze(0).to(device)
100
+
101
+ # Inference
102
+ with torch.no_grad():
103
+ logits = model(image_tensor)
104
+ prob = torch.sigmoid(logits).item()
105
+
106
+ probs.append(prob)
107
+ frame_indices.append(count)
108
+ processed_count += 1
109
+
110
+ # If highly fake, store metadata (timestamp)
111
+ if prob > 0.5:
112
+ timestamp = count / fps
113
+ suspicious_frames.append({
114
+ "timestamp": round(timestamp, 2),
115
+ "frame_index": count,
116
+ "fake_prob": round(prob, 4)
117
+ })
118
+
119
+ except Exception as e:
120
+ print(f"Error processing frame {count}: {e}")
121
+
122
+ count += 1
123
+
124
+ cap.release()
125
+
126
+ if processed_count == 0:
127
+ return {"error": "No frames processed"}
128
+
129
+ # Aggregation
130
+ avg_prob = sum(probs) / len(probs)
131
+ max_prob = max(probs)
132
+ fake_frame_count = len([p for p in probs if p > 0.6]) # Stricter frame threshold
133
+ fake_ratio = fake_frame_count / processed_count
134
+
135
+ # Verdict Logic (Tuned for High Efficiency Model)
136
+ # The new model is detecting everything as fake, so we need stricter rules.
137
+
138
+ # 1. Standard Average Check (shifted)
139
+ cond1 = avg_prob > 0.65
140
+
141
+ # 2. Density Check: Require at least 15% of frames to be strictly fake
142
+ # Was 5%, which is too low for a sensitive model
143
+ cond2 = fake_ratio > 0.15 and max_prob > 0.7
144
+
145
+ # 3. Peak Check: Only flag single-frame anomalies if EXTREMELY suspicious
146
+ cond3 = max_prob > 0.95
147
+
148
+ is_fake = cond1 or cond2 or cond3
149
+
150
+ verdict = "FAKE" if is_fake else "REAL"
151
+
152
+ # Confidence Calculation
153
+ if is_fake:
154
+ confidence = max(max_prob, 0.6)
155
+ else:
156
+ confidence = 1 - avg_prob
157
+
158
+ return {
159
+ "type": "video",
160
+ "prediction": verdict,
161
+ "confidence": float(confidence),
162
+ "avg_fake_prob": float(avg_prob),
163
+ "max_fake_prob": float(max_prob),
164
+ "fake_frame_ratio": float(fake_ratio),
165
+ "processed_frames": processed_count,
166
+ "duration": float(duration),
167
+ "timeline": [
168
+ {"time": round(i / fps, 2), "prob": round(p, 3)}
169
+ for i, p in zip(frame_indices, probs)
170
+ ],
171
+ "suspicious_frames": suspicious_frames[:10] # Top 10 suspicious moments
172
+ }