File size: 12,823 Bytes
0966609 655ee8e 0966609 a68ce06 0966609 655ee8e a68ce06 0966609 12bb6d1 0966609 ff25bf3 0966609 ff25bf3 0966609 ff25bf3 0966609 ff25bf3 0966609 ff25bf3 0966609 ff25bf3 e0cfe11 655ee8e e0cfe11 0966609 e0cfe11 0966609 ff25bf3 0966609 ff25bf3 860d30f ff25bf3 0966609 46337a4 ff25bf3 0966609 8ee24ff 0966609 b840495 2a8b4dc 0966609 2a8b4dc 0966609 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 | from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import sys
import os
# Add model directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'model')))
import datetime
import torch
import cv2
import os
import numpy as np
import ssl
import base64
from werkzeug.utils import secure_filename
import io
from PIL import Image
from src import video_inference
# Disable SSL verification
ssl._create_default_https_context = ssl._create_unverified_context
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.pytorch import ToTensorV2
from src.models import DeepfakeDetector
from src.config import Config
import database
# Track whether the optional safetensors dependency is importable.  The
# import-error text is retained so load_model() can surface it to clients
# when checkpoint loading later fails.
safetensors_import_error = None
try:
    from safetensors.torch import load_file
    SAFETENSORS_AVAILABLE = True
    print("✅ safetensors library loaded successfully")
except ImportError as e:
    SAFETENSORS_AVAILABLE = False
    safetensors_import_error = str(e)
    print(f"❌ Failed to import safetensors: {e}")
# Flask app serves the bundled frontend as static files from ../frontend.
app = Flask(__name__, static_folder='../frontend', static_url_path='')
# Allow cross-origin requests from any origin on every route — the frontend
# may be served from a different host/port during development.
CORS(app, resources={r"/*": {"origins": "*"}})
# Configuration
# Transient uploads land in UPLOAD_FOLDER and are deleted after inference;
# copies kept for the history view go to HISTORY_FOLDER.
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp', 'mp4', 'avi', 'mov', 'webm'}
HISTORY_FOLDER = os.path.join(os.path.dirname(__file__), 'history_uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(HISTORY_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
# Global model, transform, and error state
# All three are populated by load_model(); loading_error carries a
# human-readable diagnostic surfaced through the prediction endpoints.
device = torch.device(Config.DEVICE)
model = None
transform = None
loading_error = None
def get_transform():
    """Build the inference preprocessing pipeline.

    Resizes to the configured square size, applies ImageNet
    mean/std normalization, and converts to a CHW tensor.
    """
    steps = [
        A.Resize(Config.IMAGE_SIZE, Config.IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def load_model():
    """Load the trained deepfake detection model into module globals.

    Populates the module-level ``model``, ``transform`` and
    ``loading_error`` globals and returns ``(model, transform)``.
    On failure ``model`` is left as ``None`` and ``loading_error``
    carries a diagnostic that the API endpoints report to clients.
    """
    global model, transform, loading_error
    checkpoint_dir = Config.CHECKPOINT_DIR
    # Explicitly target the model requested by the user
    target_model_name = "best_model.safetensors"
    checkpoint_path = os.path.join(checkpoint_dir, target_model_name)
    print(f"Using device: {device}")
    # Initialize with pretrained=True to ensure missing keys (frozen layers) have valid ImageNet weights
    # instead of random noise. This fixes the "random prediction" issue when the checkpoint
    # only contains finetuned layers.
    try:
        model = DeepfakeDetector(pretrained=True)
        model.to(device)
        model.eval()
    except Exception as e:
        loading_error = f"Failed to init model architecture: {str(e)}"
        print(loading_error)
        model = None
        return None, None
    # Check if file exists first; include the directory listing in the error
    # to make misconfigured checkpoint paths easy to diagnose remotely.
    if not os.path.exists(checkpoint_path):
        loading_error = f"File not found: {checkpoint_path}. Contents of {checkpoint_dir}: {os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else 'Dir missing'}"
        print(f"❌ {loading_error}")
        model = None
        transform = get_transform()
        return model, transform
    try:
        print(f"Loading checkpoint: {checkpoint_path}")
        if checkpoint_path.endswith(".safetensors"):
            if SAFETENSORS_AVAILABLE:
                state_dict = load_file(checkpoint_path)
            else:
                # Fallback to torch.load even for safetensors if they are actually pickles
                # PyTorch 2.6+ requires weights_only=False for legacy pickles
                print("WARNING: safetensors not found, attempting torch.load with weights_only=False")
                # Record the import failure now; cleared below if the
                # fallback load actually succeeds.
                loading_error = f"Safetensors import failed: {safetensors_import_error}. Fallback torch.load failed."
                state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
        else:
            state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # strict=False: the checkpoint may only contain the finetuned subset;
        # missing keys keep their pretrained ImageNet weights.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f"NOTE: {len(missing_keys)} keys missing from checkpoint (kept at pretrained init)")
        if unexpected_keys:
            print(f"NOTE: {len(unexpected_keys)} unexpected keys in checkpoint (ignored)")
        print("✅ Model loaded successfully!")
        loading_error = None  # Clear error on success
    except Exception as e:
        loading_error = f"Error loading checkpoint: {str(e)}"
        if safetensors_import_error:
            loading_error += f" | NOTE: Safetensors lib failed to import: {safetensors_import_error}"
        print(f"❌ {loading_error}")
        model = None
    transform = get_transform()
    return model, transform
def allowed_file(filename, allowed=None):
    """Return True if *filename* has a permitted extension.

    Args:
        filename: Candidate file name (e.g. from an upload).
        allowed: Optional set of lowercase extensions to check against;
            defaults to the module-level ALLOWED_EXTENSIONS.
    """
    if allowed is None:
        allowed = ALLOWED_EXTENSIONS
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed
def predict_image(image_path):
    """Run the deepfake detector on a single image file.

    Returns ``(result, error)``: on success ``result`` is a dict with the
    label, probabilities and a base64-encoded heatmap overlay and ``error``
    is None; on failure ``result`` is None and ``error`` is a message.
    """
    if model is None:
        return None, f"Model Error: {loading_error}"
    try:
        # Read and preprocess image (cv2 loads BGR; the model expects RGB)
        image = cv2.imread(image_path)
        if image is None:
            return None, "Error: Could not read image"
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = transform(image=image)
        image_tensor = augmented['image'].unsqueeze(0).to(device)
        # Make prediction — inference only, so disable autograd tracking
        # to avoid building an unnecessary computation graph.
        with torch.no_grad():
            logits = model(image_tensor)
            prob = torch.sigmoid(logits).item()
        # Generate Heatmap (separate forward pass; left outside no_grad in
        # case the heatmap method relies on gradients — TODO confirm)
        heatmap = model.get_heatmap(image_tensor)
        # Process Heatmap for Visualization: resize to original image size
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # Superimpose. Heatmap is BGR (from cv2), Image is RGB -> convert image to BGR.
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        superimposed_img = heatmap * 0.4 + image_bgr * 0.6
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        # Encode overlay to base64 JPEG for the JSON response
        _, buffer = cv2.imencode('.jpg', superimposed_img)
        heatmap_b64 = base64.b64encode(buffer).decode('utf-8')
        # Threshold at 0.5; confidence is the probability of the chosen class
        is_fake = prob > 0.5
        label = "FAKE" if is_fake else "REAL"
        confidence = prob if is_fake else 1 - prob
        return {
            'prediction': label,
            'confidence': float(confidence),
            'fake_probability': float(prob),
            'real_probability': float(1 - prob),
            'heatmap': heatmap_b64
        }, None
    except Exception as e:
        return None, str(e)
@app.route('/')
def index():
    """Serve the simple demo frontend page."""
    demo_dir, demo_page = 'static', 'demo.html'
    return send_from_directory(demo_dir, demo_page)
@app.route('/history_uploads/<path:filename>')
def serve_history_image(filename):
    """Serve stored history media (thumbnails/videos) by filename."""
    return send_from_directory(HISTORY_FOLDER, filename)
@app.route('/api/health', methods=['GET'])
def health_check():
    """Report service liveness, model availability, and compute device."""
    payload = {
        'status': 'healthy',
        'model_loaded': model is not None,
        'device': str(device)
    }
    return jsonify(payload)
@app.route('/api/predict', methods=['POST'])
def predict():
    """Handle image upload and prediction.

    Expects a multipart form with a 'file' field.  Saves the upload
    temporarily, runs the detector, always deletes the temp file, and
    returns the prediction JSON (or an error payload with a 4xx/5xx code).
    """
    try:
        # Validate the multipart upload.
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: png, jpg, jpeg, webp'}), 400
        # Persist to the upload folder under a sanitized name.
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        # Make prediction, then delete the upload whether or not it succeeded
        # (the original leaked the temp file on the error path and had an
        # unreachable duplicate cleanup block after the first return).
        try:
            result, error = predict_image(filepath)
        finally:
            try:
                if os.path.exists(filepath):
                    os.remove(filepath)
            except OSError:
                pass
        if result is None:
            return jsonify({'error': error}), 500
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/api/predict_video', methods=['POST'])
def predict_video():
    """Handle video upload and prediction.

    Saves the upload, runs frame-level inference via video_inference,
    archives a copy for the history view, records the scan in the
    database, and returns the result JSON with a playback URL.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type'}), 400
        # Save file under a sanitized name.
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        # Process Video — we pass the already-loaded 'model' object, so
        # video_inference does not need to resolve model paths itself.
        if model is None:
            return jsonify({'error': 'Model not loaded'}), 500
        result = video_inference.process_video(filepath, model, transform, device)
        if "error" in result:
            return jsonify(result), 500
        # Keep a copy of the video for history playback.  Embed the original
        # (sanitized) filename so the extension survives — the frontend needs
        # it to play the file — and the timestamp for uniqueness.
        import shutil
        history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
        history_path = os.path.join(HISTORY_FOLDER, history_filename)
        shutil.copy(filepath, history_path)
        relative_path = f"history_uploads/{history_filename}"
        # Add to database.  The image-oriented 'fake_prob' column is reused
        # to store the per-video average fake probability.
        database.add_scan(
            filename=filename,
            prediction=result['prediction'],
            confidence=result['confidence'],
            fake_prob=result['avg_fake_prob'],
            real_prob=1 - result['avg_fake_prob'],
            image_path=relative_path
        )
        # Best-effort removal of the temporary upload.
        try:
            os.remove(filepath)
        except OSError:
            pass
        # Add video URL for frontend playback
        result['video_url'] = relative_path
        return jsonify(result)
    except Exception as e:
        print(f"Video Error: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/history', methods=['GET'])
def get_history():
    """Get all past scans as JSON."""
    # Single fetch — the original called database.get_history() twice,
    # discarding the first result.
    history = database.get_history()
    return jsonify(history)
@app.route('/api/history/<int:scan_id>', methods=['DELETE'])
def delete_scan(scan_id):
    """Delete a specific scan by its database id."""
    deleted = database.delete_scan(scan_id)
    if not deleted:
        return jsonify({'error': 'Failed to delete scan'}), 500
    return jsonify({'message': 'Scan deleted'})
@app.route('/api/history', methods=['DELETE'])
def clear_history():
    """Clear all stored scan history."""
    cleared = database.clear_history()
    if not cleared:
        return jsonify({'error': 'Failed to clear history'}), 500
    return jsonify({'message': 'History cleared'})
@app.route('/api/model-info', methods=['GET'])
def model_info():
    """Return static model/configuration metadata for the frontend."""
    components = {
        'RGB Analysis': Config.USE_RGB,
        'Frequency Domain': Config.USE_FREQ,
        'Patch-based Detection': Config.USE_PATCH,
        'Vision Transformer': Config.USE_VIT
    }
    info = {
        'model_name': 'DeepGuard: Advanced Deepfake Detector',
        'architecture': 'Hybrid CNN-ViT',
        'components': components,
        'image_size': Config.IMAGE_SIZE,
        'device': str(device),
        'threshold': 0.5
    }
    return jsonify(info)
if __name__ == '__main__':
    # Startup banner (leading characters appear to be mangled emoji from
    # the original source — left byte-identical).
    print("=" * 60)
    print("π DeepGuard - Deepfake Detection System")
    print("=" * 60)
    # Load model once at startup; failures are recorded in loading_error
    # and reported by the API rather than aborting the server.
    load_model()
    print("=" * 60)
    # Port is overridable via the PORT env var (e.g. for hosted deployments).
    port = int(os.environ.get("PORT", 7860))
    print(f"π Starting server on http://0.0.0.0:{port}")
    print("=" * 60)
    app.run(debug=False, host='0.0.0.0', port=port)
|