# Fake / inference.py
# Author: eesfeg
# Revision: 3cce3e9 (commit message: "tg")
import os
import joblib
import numpy as np
import tensorflow as tf
from keras.models import load_model
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import io
class FakeImageDetector:
    """Classify images as real or fake with a CNN feature extractor + classifier.

    Pipeline: ``preprocess_image`` (RGB, resized to ``img_size`` x ``img_size``,
    scaled to [0, 1]) -> ``feature_extractor.predict`` (Keras model) ->
    ``classifier.predict`` / ``predict_proba`` (presumably a scikit-learn GBDT,
    given the default ``gbdt_model.pkl`` — confirm). Label 0 is reported as
    "Real"; any other label is reported as "Fake".
    """

    def __init__(self, feature_extractor_path="hybrid_model_weights.h5",
                 classifier_path="gbdt_model.pkl", img_size=224):
        """Create the detector and load both models from disk.

        Args:
            feature_extractor_path: File passed to ``keras.models.load_model``.
                NOTE(review): ``load_model`` needs a full saved model, not a
                weights-only file — confirm the ``.h5`` contains the model.
            classifier_path: joblib file holding an object with ``predict``
                and ``predict_proba``.
            img_size: Side length in pixels that images are resized to.
        """
        self.img_size = img_size
        self.feature_extractor = None
        self.classifier = None
        self.load_models(feature_extractor_path, classifier_path)

    def load_models(self, feature_extractor_path, classifier_path):
        """Load the pre-trained feature extractor and classifier."""
        print("🔄 Loading models...")
        self.feature_extractor = load_model(feature_extractor_path, compile=False)
        # SECURITY: joblib.load unpickles arbitrary objects and can execute code.
        # Only load classifier files that come from a trusted source.
        self.classifier = joblib.load(classifier_path)
        print("✅ Models loaded successfully")

    def preprocess_image(self, image_input):
        """Convert an image into the feature extractor's input format.

        Supports: file path (str), numpy array, binary file-like object, or
        PIL Image.

        Numpy arrays: 2-D or single-channel arrays are treated as grayscale and
        3-channel arrays as BGR (OpenCV order). File-like objects are read from
        their current position to the end, so the stream is consumed.

        Returns:
            float32 RGB array of shape (img_size, img_size, 3) scaled to [0, 1],
            or None if the image could not be read or converted (the error is
            printed).
        """
        try:
            if isinstance(image_input, str):  # File path
                img = cv2.imread(image_input)
                if img is None:
                    raise ValueError(f"Could not read image from {image_input}")
            elif isinstance(image_input, np.ndarray):  # Numpy array
                img = image_input.copy()
                if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):  # Grayscale
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                elif img.shape[2] == 4:  # RGBA
                    # NOTE(review): 3-channel arrays are assumed BGR, but 4-channel
                    # arrays are assumed RGBA here; a BGRA array (e.g. from
                    # cv2.imread(..., IMREAD_UNCHANGED)) gets red/blue swapped.
                    img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
            elif hasattr(image_input, 'read'):  # File-like object
                img_array = np.frombuffer(image_input.read(), np.uint8)
                # imdecode errors on an empty buffer and returns None on bad data.
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) if img_array.size else None
                if img is None:
                    raise ValueError("Could not decode image from file-like object")
            else:  # Assume PIL Image
                if hasattr(image_input, 'convert'):
                    # Normalise modes such as 'L', 'RGBA' or 'P' to 3-channel RGB;
                    # otherwise the RGB2BGR conversion below fails.
                    image_input = image_input.convert('RGB')
                img = cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)
            # Everything above yields BGR; the model consumes RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)
            return img.astype('float32') / 255.0
        except Exception as e:
            # Best-effort by design: callers treat None as "could not process".
            print(f"Error preprocessing image: {str(e)}")
            return None

    def extract_features(self, image):
        """Run the feature extractor on one preprocessed image.

        Args:
            image: Output of ``preprocess_image`` (a single image, no batch axis).

        Returns:
            The extractor's output for a batch of one. Presumably shaped
            (1, n_features) so it can go straight to the classifier — TODO
            confirm against the saved model.
        """
        image_batch = np.expand_dims(image, axis=0)
        return self.feature_extractor.predict(image_batch, verbose=0)

    @staticmethod
    def _label_confidence(probabilities, classes, label):
        """Return the probability, in percent, that ``probabilities`` gives ``label``.

        Args:
            probabilities: One row of ``predict_proba`` output.
            classes: The classifier's ``classes_`` (the column order of
                ``predict_proba``), or None if unavailable. With None, ``label``
                is used directly as the column index, which is only correct
                for labels 0..n-1 in order.
            label: The predicted label.
        """
        column = int(label) if classes is None else list(classes).index(label)
        return float(probabilities[column]) * 100.0

    def _classify(self, processed_image):
        """Classify an already-preprocessed image and build the result dict."""
        features = self.extract_features(processed_image)
        prediction = self.classifier.predict(features)[0]
        probabilities = self.classifier.predict_proba(features)[0]
        confidence = self._label_confidence(
            probabilities, getattr(self.classifier, 'classes_', None), prediction)
        return {
            'prediction': 'Real' if prediction == 0 else 'Fake',
            'confidence': confidence,
            'prediction_code': int(prediction),
        }

    def predict(self, image_input, return_confidence=False):
        """Predict whether an image is real or fake.

        Args:
            image_input: File path, numpy array, file-like object, or PIL Image.
            return_confidence: If True, return the full result dict
                ({'prediction', 'confidence', 'prediction_code'}); otherwise
                return only the 'Real' / 'Fake' label.

        Returns:
            The label or result dict (0 = Real, anything else = Fake), or None
            if the image could not be preprocessed.
        """
        processed_image = self.preprocess_image(image_input)
        if processed_image is None:
            return None
        result = self._classify(processed_image)
        return result if return_confidence else result['prediction']

    def predict_batch(self, image_paths):
        """Predict each image in ``image_paths`` one at a time.

        Images that cannot be preprocessed are skipped; an error is printed
        for each.

        Returns:
            A list of dicts with keys 'image_path', 'prediction' and 'confidence'.
        """
        results = []
        for img_path in image_paths:
            result = self.predict(img_path, return_confidence=True)
            if result is not None:
                results.append({
                    'image_path': img_path,
                    'prediction': result['prediction'],
                    'confidence': result['confidence'],
                })
        return results

    def visualize_prediction(self, image_input, save_path=None):
        """Predict, then show the image with its label and confidence.

        Args:
            image_input: Any input accepted by ``predict``.
            save_path: If given, also save the figure to this path.

        Returns:
            The result dict (as from ``predict(..., return_confidence=True)``),
            or None if the image could not be processed.
        """
        # Preprocess exactly once: a file-like input is consumed by the first
        # read, so preprocessing again for display would fail.
        processed_image = self.preprocess_image(image_input)
        if processed_image is None:
            print("❌ Could not process image")
            return None
        result = self._classify(processed_image)

        # For paths, show the full-resolution file; otherwise show the resized
        # model input.
        if isinstance(image_input, str):
            display_img = cv2.cvtColor(cv2.imread(image_input), cv2.COLOR_BGR2RGB)
        else:
            display_img = (processed_image * 255).astype(np.uint8)

        fig = plt.figure(figsize=(8, 6))
        plt.imshow(display_img)
        plt.title(f"Prediction: {result['prediction']} ({result['confidence']:.1f}%)",
                  fontsize=16, pad=20)
        # Colour-code the confidence badge: green for Real, red for Fake.
        color = 'green' if result['prediction'] == 'Real' else 'red'
        ax = plt.gca()
        ax.text(0.5, -0.1,
                f"Confidence: {result['confidence']:.1f}%",
                ha='center', va='center',
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.3",
                          facecolor=color,
                          alpha=0.5))
        plt.axis('off')
        if save_path:
            plt.savefig(save_path, dpi=100, bbox_inches='tight')
            print(f"✅ Visualization saved to {save_path}")
        plt.show()
        # Under non-interactive backends (e.g. Agg on a server) figures are never
        # released otherwise and accumulate across calls. In interactive mode the
        # window is left open for the user.
        if not plt.isinteractive():
            plt.close(fig)
        return result
# Example usage
if __name__ == "__main__":
    # Initialize detector (loads the default model files from the working directory)
    detector = FakeImageDetector()

    # Example 1: Predict from file path.
    # Check for the same file we actually predict on and visualize.
    sample_path = "fake_0.jpg"
    if os.path.exists(sample_path):
        result = detector.predict(sample_path, return_confidence=True)
        # predict() returns None when the file cannot be decoded.
        if result is None:
            print(f"❌ Could not process {sample_path}")
        else:
            print(f"Prediction: {result['prediction']} ({result['confidence']:.1f}%)")
            # Visualize
            detector.visualize_prediction(sample_path)

    # Example 2: Predict from URL (requires additional libraries)
    # import requests
    # response = requests.get("https://example.com/image.jpg")
    # result = detector.predict(io.BytesIO(response.content))

    # Example 3: Batch prediction
    # image_list = ["img1.jpg", "img2.jpg", "img3.jpg"]
    # results = detector.predict_batch(image_list)
    # for res in results:
    #     print(f"{res['image_path']}: {res['prediction']} ({res['confidence']:.1f}%)")