eesfeg commited on
Commit
9499ccb
·
1 Parent(s): 4a92d80
Files changed (1) hide show
  1. app.py +102 -287
app.py CHANGED
@@ -3,334 +3,149 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
 
5
  import numpy as np
6
- from PIL import Image
7
  import tensorflow as tf
8
- from tensorflow.keras.models import load_model, Model
9
- from tensorflow.keras.layers import Input
 
 
 
 
 
10
  import joblib
11
  import gradio as gr
12
  import cv2
13
- import h5py
14
-
15
- from custom_objects import get_custom_objects
16
 
17
  # ======================================================
18
  # CONFIG
19
  # ======================================================
20
  IMG_SIZE = 224
 
 
21
 
22
- # ======================================================
23
- # DEBUG HYBRID MODEL
24
- # ======================================================
25
- def debug_hybrid_model():
26
- """Debug the hybrid_model.keras file"""
27
- print("\n🔍 Debugging hybrid_model_weights.h5...")
28
-
29
- try:
30
- # Method 1: Inspect the file directly
31
- print("Method 1: Inspecting HDF5 structure...")
32
- with h5py.File('hybrid_model_weights.h5', 'r') as f:
33
- print("Keys in file:", list(f.keys()))
34
- if 'model_weights' in f:
35
- print("Model weights groups:", list(f['model_weights'].keys()))
36
- except Exception as e:
37
- print(f"HDF5 inspection failed: {e}")
38
-
39
- # Method 2: Try to load with different approaches
40
- print("\nMethod 2: Trying different loading strategies...")
41
-
42
- # Strategy A: Load without custom objects first
43
- try:
44
- model = tf.keras.models.load_model('hybrid_model_weights.h5', compile=False)
45
- print("✓ Loaded without custom objects")
46
- return model
47
- except Exception as e:
48
- print(f"✗ Strategy A failed: {e}")
49
-
50
- # Strategy B: Try to rebuild from config
51
- try:
52
- print("\nTrying to rebuild from JSON config...")
53
- # Check if there's a JSON config
54
- with h5py.File('hybrid_model_weights.h5', 'r') as f:
55
- if 'model_config' in f:
56
- config = f['model_config'][()]
57
- config_str = config.decode('utf-8') if isinstance(config, bytes) else config
58
-
59
- # Try to load from JSON
60
- import json
61
- model_config = json.loads(config_str)
62
-
63
- # Try to create model from config
64
- model = tf.keras.models.model_from_json(
65
- config_str,
66
- custom_objects=get_custom_objects()
67
- )
68
-
69
- # Try to load weights
70
- model.load_weights('hybrid_model_weights.h5', by_name=True, skip_mismatch=True)
71
- print("✓ Rebuilt from config with custom objects")
72
- return model
73
- except Exception as e:
74
- print(f"✗ Strategy B failed: {e}")
75
-
76
- # Strategy C: Extract just the feature extraction part
77
- try:
78
- print("\nTrying to extract feature extractor submodel...")
79
- # Load the full model first
80
- full_model = tf.keras.models.load_model(
81
- 'hybrid_model_weights.h5',
82
- custom_objects=get_custom_objects(),
83
- compile=False
84
- )
85
-
86
- # Try to find the feature extractor layer
87
- # Common patterns for feature extractors
88
- layer_names = [layer.name for layer in full_model.layers]
89
- print(f"Available layers: {layer_names}")
90
-
91
- # Look for feature/dense/flatten layers
92
- feature_layer_names = []
93
- for name in layer_names:
94
- if 'feature' in name.lower() or 'dense' in name or 'flatten' in name or 'global' in name:
95
- feature_layer_names.append(name)
96
-
97
- if feature_layer_names:
98
- print(f"Potential feature layers: {feature_layer_names}")
99
- # Use the last dense layer before classification
100
- for layer_name in reversed(feature_layer_names):
101
- try:
102
- extractor = Model(
103
- inputs=full_model.input,
104
- outputs=full_model.get_layer(layer_name).output
105
- )
106
- print(f"✓ Created extractor from layer: {layer_name}")
107
- return extractor
108
- except:
109
- continue
110
-
111
- # If no specific layer found, try to remove classification layers
112
- # Assuming the model ends with Dense layers for classification
113
- for i, layer in enumerate(reversed(full_model.layers)):
114
- if isinstance(layer, tf.keras.layers.Dense) and layer.units <= 2: # Classification layer
115
- # Get output from layer before classification
116
- extractor = Model(
117
- inputs=full_model.input,
118
- outputs=full_model.layers[-i-2].output
119
- )
120
- print(f"✓ Created extractor by removing last {i+1} classification layers")
121
- return extractor
122
-
123
- except Exception as e:
124
- print(f"✗ Strategy C failed: {e}")
125
-
126
- return None
127
 
128
  # ======================================================
129
- # FALLBACK EXTRACTOR
130
  # ======================================================
131
- def create_fallback_extractor():
132
- """Create fallback extractor if hybrid model fails"""
133
- print("\nCreating fallback MobileNetV2 extractor...")
134
-
135
- base_model = tf.keras.applications.MobileNetV2(
136
- input_shape=(IMG_SIZE, IMG_SIZE, 3),
137
  include_top=False,
138
- weights="imagenet",
139
- pooling="avg"
 
140
  )
141
- base_model.trainable = False
142
-
143
- inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
144
- x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
145
- features = base_model(x, training=False)
146
-
147
- # Add similar architecture to your hybrid model
148
- x = tf.keras.layers.Dense(512, activation="relu")(features)
149
- x = tf.keras.layers.Dropout(0.3)(x)
150
- x = tf.keras.layers.Dense(256, activation="relu")(x)
151
- x = tf.keras.layers.Dense(128, activation="relu")(x)
152
-
153
- model = Model(inputs, x, name="fallback_extractor")
154
- print(f"✓ Fallback extractor created. Output shape: {model.output_shape}")
155
- return model
 
 
 
 
 
 
 
156
 
157
  # ======================================================
158
  # LOAD MODELS
159
  # ======================================================
160
- extractor, classifier = None, None
161
-
162
  def load_models():
163
  global extractor, classifier
164
-
165
- print("\n" + "="*50)
166
- print("LOADING HYBRID MODEL")
167
- print("="*50)
168
-
169
- # 1. Try to load hybrid model with debugging
170
- extractor = debug_hybrid_model()
171
-
172
- if extractor is None:
173
- print("\n❌ Could not load hybrid_model.keras")
174
- print("Creating fallback extractor...")
175
- extractor = create_fallback_extractor()
176
- else:
177
- print(f"\n✅ Hybrid model loaded successfully!")
178
- print(f" Input shape: {extractor.input_shape}")
179
- print(f" Output shape: {extractor.output_shape}")
180
- print(f" Number of layers: {len(extractor.layers)}")
181
-
182
- # Test the extractor
183
- print("\n🧪 Testing extractor with random input...")
184
- test_input = np.random.randn(1, IMG_SIZE, IMG_SIZE, 3).astype(np.float32)
185
- test_output = extractor.predict(test_input, verbose=0)
186
- print(f" Test output shape: {test_output.shape}")
187
-
188
- # 2. Load classifier
189
- print("\n" + "="*50)
190
- print("LOADING CLASSIFIER")
191
- print("="*50)
192
-
193
- try:
194
- classifier_files = ["gbdt_model.pkl", "classifier.pkl", "rf_model.pkl"]
195
-
196
- for cf in classifier_files:
197
- if os.path.exists(cf):
198
- classifier = joblib.load(cf)
199
- print(f"✓ Loaded classifier: {cf}")
200
- print(f" Type: {type(classifier).__name__}")
201
-
202
- # Check if it's a pipeline
203
- if hasattr(classifier, 'steps'):
204
- print(f" Pipeline steps: {[name for name, _ in classifier.steps]}")
205
-
206
- # Test classifier
207
- if extractor is not None:
208
- output_dim = extractor.output_shape[-1]
209
- test_features = np.random.randn(1, output_dim)
210
- test_pred = classifier.predict(test_features)
211
- print(f" Test prediction: {test_pred[0]}")
212
- break
213
- except Exception as e:
214
- print(f"✗ Classifier loading failed: {e}")
215
-
216
- # Create simple fallback
217
- from sklearn.ensemble import RandomForestClassifier
218
- output_dim = extractor.output_shape[-1] if extractor else 128
219
- classifier = RandomForestClassifier(n_estimators=50, random_state=42)
220
- dummy_features = np.random.randn(100, output_dim)
221
- dummy_labels = np.random.randint(0, 2, 100)
222
- classifier.fit(dummy_features, dummy_labels)
223
- print("✓ Created fallback classifier")
224
-
225
- print("\n" + "="*50)
226
- print("MODELS READY FOR INFERENCE")
227
- print("="*50)
228
 
229
  # ======================================================
230
- # PREPROCESSING FOR HYBRID MODEL
231
  # ======================================================
232
  def preprocess_image(img):
233
- """Preprocess image for the hybrid model"""
234
- # Convert to numpy
235
- if isinstance(img, Image.Image):
236
- img = np.array(img)
237
-
238
- # Handle different formats
239
- if len(img.shape) == 2: # Grayscale
240
- img = np.stack([img] * 3, axis=-1)
241
- elif img.shape[2] == 4: # RGBA
242
  img = img[:, :, :3]
243
-
244
- # Convert to RGB if needed
245
- if img.shape[2] == 3:
246
- # Check if BGR (OpenCV)
247
- if img[0, 0, 0] > img[0, 0, 2]: # Blue > Red
248
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
249
-
250
- # Resize
251
  img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
252
-
253
- # Normalize to [0, 1] - common for custom models
254
  img = img.astype(np.float32) / 255.0
255
-
256
  return img
257
 
258
  # ======================================================
259
  # PREDICTION
260
  # ======================================================
261
  def predict(img):
262
- """Make prediction using hybrid model"""
 
 
 
 
 
 
 
263
  try:
264
- # Preprocess
265
- img_processed = preprocess_image(img)
266
- img_batch = np.expand_dims(img_processed, axis=0)
267
-
268
- # Extract features
269
- features = extractor.predict(img_batch, verbose=0)
270
-
271
- # Flatten if needed
272
- if len(features.shape) > 2:
273
- features = features.reshape(features.shape[0], -1)
274
-
275
- # Classify
276
- pred = classifier.predict(features)[0]
277
-
278
- # Get confidence
279
- try:
280
- proba = classifier.predict_proba(features)[0]
281
- confidence = proba[pred] * 100
282
- except:
283
- confidence = 80.0 # Default confidence
284
-
285
- # Return results
286
- label = "Real" if pred == 0 else "Fake"
287
- return {
288
- "Real": confidence if label == "Real" else 100 - confidence,
289
- "Fake": confidence if label == "Fake" else 100 - confidence
290
- }
291
-
292
- except Exception as e:
293
- print(f"Prediction error: {e}")
294
- return {"Real": 50.0, "Fake": 50.0}
295
 
296
  # ======================================================
297
- # CREATE INTERFACE
298
  # ======================================================
299
- def create_interface():
300
- """Create Gradio interface"""
301
- # Load models first
302
  load_models()
303
-
304
- # Create interface
305
  iface = gr.Interface(
306
  fn=predict,
307
- inputs=gr.Image(
308
- type="pil",
309
- label="Upload Image",
310
- image_mode="RGB"
311
- ),
312
- outputs=gr.Label(
313
- num_top_classes=2,
314
- label="Prediction"
315
- ),
316
- title="Hybrid Model Fake Image Detector",
317
- description="Using hybrid_model.keras + GBDT classifier",
318
- theme=gr.themes.Soft()
319
  )
320
-
321
- return iface
322
 
323
- # ======================================================
324
- # MAIN
325
- # ======================================================
326
  if __name__ == "__main__":
327
- print("\n🚀 Starting Hybrid Model Detector...")
328
-
329
- # Create and launch
330
- interface = create_interface()
331
-
332
- interface.launch(
333
- server_name="0.0.0.0",
334
- server_port=7860,
335
- share=False
336
- )
 
3
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
 
5
  import numpy as np
 
6
  import tensorflow as tf
7
+ from tensorflow.keras.applications import EfficientNetB7, InceptionResNetV2
8
+ from tensorflow.keras.layers import (
9
+ Input, GlobalAveragePooling2D,
10
+ Flatten, Concatenate
11
+ )
12
+ from tensorflow.keras.models import Model
13
+ from vit_keras import vit
14
  import joblib
15
  import gradio as gr
16
  import cv2
 
 
 
17
 
18
  # ======================================================
19
  # CONFIG
20
  # ======================================================
21
  IMG_SIZE = 224
22
+ WEIGHTS_PATH = "hybrid_model_weights.h5"
23
+ CLASSIFIER_PATH = "gbdt_model.pkl"
24
 
25
+ extractor = None
26
+ classifier = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # ======================================================
29
+ # BUILD HYBRID EXTRACTOR (MUST MATCH TRAINING)
30
  # ======================================================
31
+ def build_hybrid_extractor():
32
+ inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="input_image")
33
+
34
+ eff = EfficientNetB7(
 
 
35
  include_top=False,
36
+ weights=None,
37
+ input_tensor=inputs,
38
+ name="EfficientNetB7_backbone"
39
  )
40
+
41
+ inc = InceptionResNetV2(
42
+ include_top=False,
43
+ weights=None,
44
+ input_tensor=inputs,
45
+ name="Inception_backbone"
46
+ )
47
+
48
+ vit_model = vit.vit_b16(
49
+ image_size=IMG_SIZE,
50
+ pretrained=False,
51
+ include_top=False
52
+ )
53
+ vit_out = vit_model(inputs)
54
+
55
+ f1 = GlobalAveragePooling2D(name="eff_gap")(eff.output)
56
+ f2 = GlobalAveragePooling2D(name="inc_gap")(inc.output)
57
+ f3 = Flatten(name="vit_flat")(vit_out)
58
+
59
+ features = Concatenate(name="merged_features")([f1, f2, f3])
60
+
61
+ return Model(inputs, features, name="HybridExtractor")
62
 
63
  # ======================================================
64
  # LOAD MODELS
65
  # ======================================================
 
 
66
  def load_models():
67
  global extractor, classifier
68
+
69
+ print("\n🚀 Loading Hybrid Extractor...")
70
+
71
+ if not os.path.exists(WEIGHTS_PATH):
72
+ raise FileNotFoundError(f"{WEIGHTS_PATH} not found")
73
+
74
+ extractor = build_hybrid_extractor()
75
+ extractor.load_weights(WEIGHTS_PATH, by_name=True)
76
+
77
+ # Sanity check
78
+ dummy = tf.zeros((1, IMG_SIZE, IMG_SIZE, 3))
79
+ feat = extractor(dummy)
80
+ print("✓ Extractor loaded | Feature dim:", feat.shape[-1])
81
+
82
+ print("\n🚀 Loading Classifier...")
83
+
84
+ if not os.path.exists(CLASSIFIER_PATH):
85
+ raise FileNotFoundError(f"{CLASSIFIER_PATH} not found")
86
+
87
+ classifier = joblib.load(CLASSIFIER_PATH)
88
+
89
+ # Verify compatibility
90
+ test_feat = np.random.randn(1, feat.shape[-1])
91
+ classifier.predict(test_feat)
92
+
93
+ print(" Classifier loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # ======================================================
96
+ # PREPROCESS IMAGE
97
  # ======================================================
98
  def preprocess_image(img):
99
+ if img is None:
100
+ raise ValueError("No image provided")
101
+
102
+ if img.shape[-1] == 4:
 
 
 
 
 
103
  img = img[:, :, :3]
104
+
 
 
 
 
 
 
 
105
  img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
 
 
106
  img = img.astype(np.float32) / 255.0
 
107
  return img
108
 
109
  # ======================================================
110
  # PREDICTION
111
  # ======================================================
112
  def predict(img):
113
+ img = preprocess_image(img)
114
+ img = np.expand_dims(img, axis=0)
115
+
116
+ features = extractor.predict(img, verbose=0)
117
+ features = features.reshape(features.shape[0], -1)
118
+
119
+ pred = classifier.predict(features)[0]
120
+
121
  try:
122
+ proba = classifier.predict_proba(features)[0]
123
+ confidence = float(np.max(proba)) * 100
124
+ except:
125
+ confidence = 80.0
126
+
127
+ label = "Fake" if pred == 1 else "Real"
128
+
129
+ return {
130
+ "Real": confidence if label == "Real" else 100 - confidence,
131
+ "Fake": confidence if label == "Fake" else 100 - confidence
132
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  # ======================================================
135
+ # GRADIO APP
136
  # ======================================================
137
+ def main():
 
 
138
  load_models()
139
+
 
140
  iface = gr.Interface(
141
  fn=predict,
142
+ inputs=gr.Image(type="numpy", label="Upload Image"),
143
+ outputs=gr.Label(num_top_classes=2),
144
+ title="Hybrid Fake Image Detector",
145
+ description="EfficientNet + Inception + ViT + GBDT"
 
 
 
 
 
 
 
 
146
  )
 
 
147
 
148
+ iface.launch(server_name="0.0.0.0", server_port=7860)
149
+
 
150
  if __name__ == "__main__":
151
+ main()