# Author: Harshasnade
# Commit 2a8b4dc: "Disable history and cleanup immediately"
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import sys
import os
# Add model directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'model')))
import datetime
import torch
import cv2
import os
import numpy as np
import ssl
import base64
from werkzeug.utils import secure_filename
import io
from PIL import Image
from src import video_inference
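# Disable SSL certificate verification for every default HTTPS context in this
# process. SECURITY: any HTTPS download made through the default context is then
# open to man-in-the-middle tampering; prefer fixing the CA bundle instead.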
ssl._create_default_https_context = ssl._create_unverified_context
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.pytorch import ToTensorV2
from src.models import DeepfakeDetector
from src.config import Config
import database
# safetensors is an optional dependency: the server must still start without
# it, and load_model() reports the recorded import error if loading then fails.
safetensors_import_error = None
try:
    from safetensors.torch import load_file
except ImportError as e:
    SAFETENSORS_AVAILABLE = False
    # Keep the text: `e` is unbound as soon as the except block ends.
    safetensors_import_error = str(e)
    print(f"❌ Failed to import safetensors: {e}")
else:
    SAFETENSORS_AVAILABLE = True
    print("✅ safetensors library loaded successfully")
# --- Flask application ------------------------------------------------------
app = Flask(__name__, static_folder='../frontend', static_url_path='')
# Cross-origin requests are accepted from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})

# --- Upload / history storage -----------------------------------------------
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp', 'mp4', 'avi', 'mov', 'webm'}
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
HISTORY_FOLDER = os.path.join(os.path.dirname(__file__), 'history_uploads')
# Both directories live next to this file and are created on import if absent.
for _storage_dir in (UPLOAD_FOLDER, HISTORY_FOLDER):
    os.makedirs(_storage_dir, exist_ok=True)

app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16MB max file size
)

# --- Inference state shared by the request handlers --------------------------
# `model` and `transform` stay None until load_model() runs; `loading_error`
# holds the most recent load failure message, or None.
device = torch.device(Config.DEVICE)
model = None
transform = None
loading_error = None
def get_transform():
    """Build the preprocessing pipeline applied to input images.

    Returns:
        An albumentations ``Compose`` that resizes to a square of side
        ``Config.IMAGE_SIZE``, normalizes each channel with the mean/std
        below (the usual ImageNet statistics), and converts the result with
        ``ToTensorV2``.
    """
    side = Config.IMAGE_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def _read_state_dict(checkpoint_path):
    """Read the weights stored at *checkpoint_path* and return the state dict."""
    is_safetensors = checkpoint_path.endswith(".safetensors")
    if is_safetensors and SAFETENSORS_AVAILABLE:
        return load_file(checkpoint_path)
    if is_safetensors:
        # Fallback to torch.load even for safetensors if they are actually pickles
        print("WARNING: safetensors not found, attempting torch.load with weights_only=False")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only ever point this at checkpoints you trust.
    # PyTorch 2.6+ requires weights_only=False for legacy pickles.
    return torch.load(checkpoint_path, map_location=device, weights_only=False)


def load_model():
    """Load the trained deepfake detection model into the module globals.

    Sets the globals ``model``, ``transform`` and ``loading_error``. On any
    failure ``model`` is set to None and ``loading_error`` describes what
    went wrong; on success ``loading_error`` is cleared.

    Returns:
        A ``(model, transform)`` tuple. ``model`` is None when the
        architecture could not be built, the checkpoint file is missing, or
        the checkpoint could not be loaded.
    """
    global model, transform, loading_error

    checkpoint_dir = Config.CHECKPOINT_DIR
    # Explicitly target the model requested by the user
    target_model_name = "best_model.safetensors"
    checkpoint_path = os.path.join(checkpoint_dir, target_model_name)
    print(f"Using device: {device}")

    # The transform does not depend on the weights, so it is set on every
    # path, including the failure paths below.
    transform = get_transform()

    # Initialize with pretrained=True so that keys missing from the checkpoint
    # (frozen layers) keep valid ImageNet weights instead of random noise.
    # This avoids "random prediction" behavior when the checkpoint only
    # contains the finetuned layers.
    try:
        model = DeepfakeDetector(pretrained=True)
        model.to(device)
        model.eval()
    except Exception as e:
        loading_error = f"Failed to init model architecture: {str(e)}"
        print(loading_error)
        model = None
        return model, transform

    # Check if file exists first, and list the directory to ease debugging.
    if not os.path.exists(checkpoint_path):
        contents = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else 'Dir missing'
        loading_error = f"File not found: {checkpoint_path}. Contents of {checkpoint_dir}: {contents}"
        print(f"❌ {loading_error}")
        model = None
        return model, transform

    try:
        print(f"Loading checkpoint: {checkpoint_path}")
        state_dict = _read_state_dict(checkpoint_path)
        # strict=False tolerates partial checkpoints, so report any mismatch
        # instead of hiding it: a wrong checkpoint would otherwise load "fine".
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f"WARNING: {len(missing_keys)} model keys not found in checkpoint (left at their initial values)")
        if unexpected_keys:
            print(f"WARNING: {len(unexpected_keys)} checkpoint keys not used by the model")
        print("✅ Model loaded successfully!")
        loading_error = None  # Clear error on success
    except Exception as e:
        loading_error = f"Error loading checkpoint: {str(e)}"
        if safetensors_import_error:
            loading_error += f" | NOTE: Safetensors lib failed to import: {safetensors_import_error}"
        print(f"❌ {loading_error}")
        model = None

    return model, transform
def allowed_file(filename):
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared
    case-insensitively. A name that contains no dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
def _encode_heatmap_overlay(heatmap, image_rgb):
    """Blend *heatmap* over *image_rgb* and return it as a base64 JPEG string."""
    height, width = image_rgb.shape[0], image_rgb.shape[1]
    # Bring the heatmap up to the original image size, then colorize it.
    resized = cv2.resize(heatmap, (width, height))
    colored = cv2.applyColorMap(np.uint8(255 * resized), cv2.COLORMAP_JET)
    # The colorized heatmap is BGR (from cv2) while the image is RGB, so the
    # image is converted to BGR before the two are mixed.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    blended = colored * 0.4 + image_bgr * 0.6
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    _, jpeg_bytes = cv2.imencode('.jpg', blended)
    return base64.b64encode(jpeg_bytes).decode('utf-8')


def predict_image(image_path):
    """Run the loaded model on the image stored at *image_path*.

    Returns:
        ``(result, None)`` on success, where ``result`` is a dict with the
        keys ``prediction`` ("FAKE" or "REAL"), ``confidence``,
        ``fake_probability``, ``real_probability`` and ``heatmap`` (a
        base64-encoded JPEG overlay); or ``(None, message)`` when the model
        is not loaded, the image cannot be read, or any step raises.
    """
    if model is None:
        return None, f"Model Error: {loading_error}"

    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            return None, "Error: Could not read image"
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Preprocess and add the batch dimension expected by the model call.
        batch = transform(image=rgb)['image'].unsqueeze(0).to(device)

        fake_prob = torch.sigmoid(model(batch)).item()
        overlay_b64 = _encode_heatmap_overlay(model.get_heatmap(batch), rgb)

        # Probabilities above 0.5 count as FAKE; confidence is the
        # probability of whichever label was chosen.
        if fake_prob > 0.5:
            label, confidence = "FAKE", fake_prob
        else:
            label, confidence = "REAL", 1 - fake_prob

        payload = {
            'prediction': label,
            'confidence': float(confidence),
            'fake_probability': float(fake_prob),
            'real_probability': float(1 - fake_prob),
            'heatmap': overlay_b64
        }
        return payload, None
    except Exception as e:
        return None, str(e)
@app.route('/')
def index():
    """Serve ``demo.html`` from the ``static`` directory as the landing page."""
    page_dir, page_name = 'static', 'demo.html'
    return send_from_directory(page_dir, page_name)
@app.route('/history_uploads/<path:filename>')
def serve_history_image(filename):
    """Serve the file named *filename* out of HISTORY_FOLDER."""
    response = send_from_directory(HISTORY_FOLDER, filename)
    return response
@app.route('/api/health', methods=['GET'])
def health_check():
    """Report liveness, whether a model is loaded, and the torch device.

    ``status`` is always "healthy"; callers that care about inference
    readiness should look at ``model_loaded``.
    """
    report = {
        'status': 'healthy',
        'model_loaded': model is not None,
        'device': str(device),
    }
    return jsonify(report)
# Extensions the image endpoint accepts; these match the list promised in its
# error message. ALLOWED_EXTENSIONS also contains video types, which
# predict_image() cannot read.
_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}


@app.route('/api/predict', methods=['POST'])
def predict():
    """Handle an image upload and return the model's prediction as JSON.

    Expects a multipart form field named ``file``.

    Responses:
        200 with the dict produced by ``predict_image`` on success.
        400 when no file is provided or selected, or the extension is not
            an accepted image type.
        500 with ``{'error': ...}`` when prediction fails or an unexpected
            exception occurs.
        Other HTTP errors raised while reading the request (for example
        413 when the body exceeds MAX_CONTENT_LENGTH) keep their own status
        code instead of being reported as 500.
    """
    import uuid
    from werkzeug.exceptions import HTTPException

    try:
        # Check if file is present
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        extension = file.filename.rsplit('.', 1)[-1].lower()
        if not allowed_file(file.filename) or extension not in _IMAGE_EXTENSIONS:
            return jsonify({'error': 'Invalid file type. Allowed types: png, jpg, jpeg, webp'}), 400

        # A random prefix keeps concurrent uploads that share a name from
        # overwriting (and then deleting) each other's temporary file.
        filename = f"{uuid.uuid4().hex}_{secure_filename(file.filename)}"
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

        try:
            file.save(filepath)
            result, error = predict_image(filepath)
        finally:
            # Delete the upload immediately, on success and on failure alike.
            try:
                os.remove(filepath)
            except OSError:
                # Best effort: the file may never have been written.
                pass

        if result is None:
            return jsonify({'error': error}), 500
        return jsonify(result)
    except HTTPException as e:
        return jsonify({'error': e.description}), e.code or 500
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/api/predict_video', methods=['POST'])
def predict_video():
    """Handle a video upload, run video inference, and record it in history.

    Expects a multipart form field named ``file``. On success the uploaded
    video is copied into HISTORY_FOLDER, a row is added through
    ``database.add_scan``, and the inference result is returned with an
    extra ``video_url`` key pointing at the stored copy.

    Responses:
        200 with the inference result on success.
        400 when no file is provided or selected, or the extension is not
            allowed.
        500 when the model is not loaded, inference reports an error, or an
            unexpected exception occurs.
        Other HTTP errors raised while reading the request (for example
        413 when the body exceeds MAX_CONTENT_LENGTH) keep their own status
        code instead of being reported as 500.
    """
    import shutil
    import uuid
    from werkzeug.exceptions import HTTPException

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type'}), 400

        # Fail before anything is written to disk.
        if model is None:
            return jsonify({'error': 'Model not loaded'}), 500

        filename = secure_filename(file.filename)
        # A random prefix keeps concurrent uploads that share a name from
        # overwriting each other's temporary file.
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{filename}")

        try:
            file.save(filepath)

            # The already loaded model object is passed in, so process_video
            # does not need to import or load the model itself.
            result = video_inference.process_video(filepath, model, transform, device)
            if "error" in result:
                return jsonify(result), 500

            # Keep a copy of the video itself in history so the frontend can
            # play it back later.
            history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
            history_path = os.path.join(HISTORY_FOLDER, history_filename)
            shutil.copy(filepath, history_path)
            relative_path = f"history_uploads/{history_filename}"

            # add_scan's fields are named for images; the averaged fake
            # probability is stored in 'fake_prob'.
            database.add_scan(
                filename=filename,
                prediction=result['prediction'],
                confidence=result['confidence'],
                fake_prob=result['avg_fake_prob'],
                real_prob=1 - result['avg_fake_prob'],
                image_path=relative_path
            )
        finally:
            # Remove the temporary upload on every exit path, not only on success.
            try:
                os.remove(filepath)
            except OSError:
                # Best effort: the file may never have been written.
                pass

        # Add video URL for frontend playback
        result['video_url'] = relative_path
        return jsonify(result)
    except HTTPException as e:
        return jsonify({'error': e.description}), e.code or 500
    except Exception as e:
        print(f"Video Error: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/history', methods=['GET'])
def get_history():
    """Return everything ``database.get_history()`` yields as a JSON response."""
    # Query once; the previous second call only repeated the same lookup.
    history = database.get_history()
    return jsonify(history)
@app.route('/api/history/<int:scan_id>', methods=['DELETE'])
def delete_scan(scan_id):
    """Delete the scan identified by *scan_id*.

    Responds with a confirmation message when ``database.delete_scan``
    returns a truthy value, and with a 500 error otherwise.
    """
    deleted = database.delete_scan(scan_id)
    if not deleted:
        return jsonify({'error': 'Failed to delete scan'}), 500
    return jsonify({'message': 'Scan deleted'})
@app.route('/api/history', methods=['DELETE'])
def clear_history():
    """Remove every stored scan.

    Responds with a confirmation message when ``database.clear_history``
    returns a truthy value, and with a 500 error otherwise.
    """
    cleared = database.clear_history()
    if not cleared:
        return jsonify({'error': 'Failed to clear history'}), 500
    return jsonify({'message': 'History cleared'})
@app.route('/api/model-info', methods=['GET'])
def model_info():
    """Describe the model: fixed name/architecture plus values from Config.

    The component flags and image size are read from ``Config``; the
    decision threshold is reported as the constant 0.5.
    """
    components = {
        'RGB Analysis': Config.USE_RGB,
        'Frequency Domain': Config.USE_FREQ,
        'Patch-based Detection': Config.USE_PATCH,
        'Vision Transformer': Config.USE_VIT,
    }
    info = {
        'model_name': 'DeepGuard: Advanced Deepfake Detector',
        'architecture': 'Hybrid CNN-ViT',
        'components': components,
        'image_size': Config.IMAGE_SIZE,
        'device': str(device),
        'threshold': 0.5,
    }
    return jsonify(info)
if __name__ == '__main__':
    # Script entry point: print a banner, load the model, then start the
    # server on all interfaces.
    # NOTE(review): in the code shown here load_model() is only called on
    # this path, so importing `app` from a WSGI server would presumably
    # leave `model` as None - confirm how this is deployed.
    banner = "=" * 60
    print(banner)
    print("🚀 DeepGuard - Deepfake Detection System")
    print(banner)

    # Load model
    load_model()
    print(banner)

    # The port comes from the PORT environment variable, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    print(f"🌐 Starting server on http://0.0.0.0:{port}")
    print(banner)
    app.run(debug=False, host='0.0.0.0', port=port)