# File size: 6,702 Bytes
# 3cce3e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import joblib
import numpy as np
import tensorflow as tf
from keras.models import load_model
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import io


class FakeImageDetector:
    """Classify images as real or fake with a two-stage pipeline.

    A Keras model produces a feature vector for each preprocessed image, and
    a classifier loaded with joblib labels that vector.  Label ``0`` is
    reported as ``'Real'``; any other label is reported as ``'Fake'``.
    """

    def __init__(self, feature_extractor_path="hybrid_model_weights.h5",
                 classifier_path="gbdt_model.pkl", img_size=224):
        """Load both models and remember the square input size.

        Args:
            feature_extractor_path: Path handed to ``keras.models.load_model``.
            classifier_path: Path handed to ``joblib.load``.
            img_size: Side length, in pixels, that images are resized to.
        """
        self.img_size = img_size
        self.feature_extractor = None
        self.classifier = None
        self.load_models(feature_extractor_path, classifier_path)

    def load_models(self, feature_extractor_path, classifier_path):
        """Load the feature extractor and the classifier from disk.

        Errors raised by either loader propagate to the caller.
        """
        print("🔄 Loading models...")
        self.feature_extractor = load_model(feature_extractor_path, compile=False)
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code; only point this at classifier files you trust.
        self.classifier = joblib.load(classifier_path)
        print("✅ Models loaded successfully")

    @staticmethod
    def _to_bgr(image_input):
        """Decode any supported input into a BGR image array.

        Raises an exception when the input cannot be read or decoded; the
        caller (``preprocess_image``) turns that into a ``None`` result.
        """
        if image_input is None:
            raise ValueError("No image provided")

        # File path (plain string or any os.PathLike such as pathlib.Path).
        if isinstance(image_input, (str, os.PathLike)):
            img = cv2.imread(os.fspath(image_input))
            if img is None:
                raise ValueError(f"Could not read image from {image_input}")
            return img

        # Numpy array. Three-channel arrays are taken to be BGR already and
        # four-channel arrays to be RGBA, as the conversions below imply.
        if isinstance(image_input, np.ndarray):
            img = image_input
            if img.ndim == 2:  # Grayscale
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            if img.ndim == 3 and img.shape[2] == 1:  # Grayscale with channel axis
                return cv2.cvtColor(np.ascontiguousarray(img[:, :, 0]),
                                    cv2.COLOR_GRAY2BGR)
            if img.ndim == 3 and img.shape[2] == 4:  # RGBA
                return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
            return img

        # File-like object holding encoded image bytes.
        if hasattr(image_input, 'read'):
            raw = image_input.read()
            img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("Could not decode image data from file-like object")
            return img

        # Anything else is assumed to be a PIL Image. Normalise the mode
        # first so grayscale / RGBA / palette images convert cleanly.
        if hasattr(image_input, 'convert'):
            image_input = image_input.convert('RGB')
        return cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)

    def preprocess_image(self, image_input):
        """Turn an image input into a normalised RGB array for the model.

        Supports: file path (``str`` or ``os.PathLike``), numpy array,
        file-like object with ``read()``, PIL Image.

        Returns:
            A float32 array of shape ``(img_size, img_size, 3)`` scaled by
            1/255, or ``None`` (after printing the error) when the input
            cannot be read or converted.
        """
        # The broad catch is deliberate: callers rely on "None means the
        # image could not be processed" rather than on exceptions.
        try:
            img = self._to_bgr(image_input)

            # Convert to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Resize and scale
            img = cv2.resize(img, (self.img_size, self.img_size),
                             interpolation=cv2.INTER_AREA)
            return img.astype('float32') / 255.0
        except Exception as e:
            print(f"Error preprocessing image: {str(e)}")
            return None

    def extract_features(self, image):
        """Run the feature extractor on one preprocessed image.

        A leading batch axis is added before calling the model; the model's
        output for that single-image batch is returned unchanged.
        """
        image_batch = np.expand_dims(image, axis=0)
        return self.feature_extractor.predict(image_batch, verbose=0)

    def _classify(self, processed_image):
        """Classify an already-preprocessed image and build the result dict."""
        features = self.extract_features(processed_image)

        label = self.classifier.predict(features)[0]
        probabilities = self.classifier.predict_proba(features)[0]

        # predict_proba columns follow the classifier's classes_ order, so
        # look the label up there instead of using the label as an index.
        classes = getattr(self.classifier, 'classes_', None)
        if classes is not None:
            column = list(classes).index(label)
        else:
            column = int(label)

        return {
            'prediction': 'Real' if label == 0 else 'Fake',
            'confidence': float(probabilities[column]) * 100,
            'prediction_code': int(label)
        }

    def predict(self, image_input, return_confidence=False):
        """
        Predict if image is real or fake

        Args:
            image_input: Any input accepted by ``preprocess_image``
            return_confidence: If True, return the full result dict

        Returns:
            ``None`` if the image could not be processed. Otherwise the
            string ``'Real'`` or ``'Fake'``, or, when ``return_confidence``
            is True, a dict with keys ``'prediction'`` (that string),
            ``'confidence'`` (percentage, 0-100) and ``'prediction_code'``
            (the classifier's integer label).
        """
        processed_image = self.preprocess_image(image_input)
        if processed_image is None:
            return None

        result = self._classify(processed_image)
        return result if return_confidence else result['prediction']

    def predict_batch(self, image_paths):
        """Predict each image in ``image_paths``.

        Returns:
            A list of dicts with keys ``'image_path'``, ``'prediction'`` and
            ``'confidence'``. Images that cannot be processed are skipped,
            so the list may be shorter than the input.
        """
        results = []
        for img_path in image_paths:
            result = self.predict(img_path, return_confidence=True)
            if result is None:
                continue
            results.append({
                'image_path': img_path,
                'prediction': result['prediction'],
                'confidence': result['confidence']
            })
        return results

    def visualize_prediction(self, image_input, save_path=None):
        """Make prediction and visualize result

        Args:
            image_input: Any input accepted by ``preprocess_image``
            save_path: Optional path; when given the figure is also saved

        Returns:
            The result dict described in ``predict``, or ``None`` if the
            image could not be processed.
        """
        # Preprocess exactly once: a file-like input can only be read a
        # single time, so the same array feeds both the model and the plot.
        processed_image = self.preprocess_image(image_input)
        if processed_image is None:
            print("❌ Could not process image")
            return None
        result = self._classify(processed_image)

        # Show the original-resolution file when we have a path; otherwise
        # fall back to the resized image that was fed to the model.
        display_img = None
        if isinstance(image_input, (str, os.PathLike)):
            original = cv2.imread(os.fspath(image_input))
            if original is not None:
                display_img = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
        if display_img is None:
            display_img = (processed_image * 255).astype(np.uint8)

        # Create visualization
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(display_img)
        ax.set_title(f"Prediction: {result['prediction']} ({result['confidence']:.1f}%)",
                     fontsize=16, pad=20)

        # Color code based on prediction
        color = 'green' if result['prediction'] == 'Real' else 'red'
        ax.text(0.5, -0.1,
                f"Confidence: {result['confidence']:.1f}%",
                ha='center', va='center',
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.3",
                          facecolor=color,
                          alpha=0.5))

        ax.axis('off')

        if save_path:
            fig.savefig(save_path, dpi=100, bbox_inches='tight')
            print(f"✅ Visualization saved to {save_path}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        return result


# Example usage
if __name__ == "__main__":
    # Initialize detector
    detector = FakeImageDetector()

    # Example 1: Predict from file path
    # Guard on the same file that is analysed below, so the demo neither
    # runs against a missing image nor is skipped when the image is present.
    sample_image = "fake_0.jpg"
    if os.path.exists(sample_image):
        result = detector.predict(sample_image, return_confidence=True)
        # predict() returns None when the image cannot be processed.
        if result is not None:
            print(f"Prediction: {result['prediction']} ({result['confidence']:.1f}%)")

            # Visualize
            detector.visualize_prediction(sample_image)

    # Example 2: Predict from URL (requires additional libraries)
    # import requests
    # response = requests.get("https://example.com/image.jpg")
    # result = detector.predict(io.BytesIO(response.content))

    # Example 3: Batch prediction
    # image_list = ["img1.jpg", "img2.jpg", "img3.jpg"]
    # results = detector.predict_batch(image_list)
    # for res in results:
    #     print(f"{res['image_path']}: {res['prediction']} ({res['confidence']:.1f}%)")