eesfeg commited on
Commit
0220026
·
1 Parent(s): 899448a
Files changed (2) hide show
  1. app.py +58 -56
  2. custom_objects.py +1 -1
app.py CHANGED
@@ -6,59 +6,63 @@ import numpy as np
6
  from PIL import Image
7
  import tensorflow as tf
8
  from tensorflow.keras.models import load_model
9
- from tensorflow.keras import layers, Model
10
  import joblib
 
11
  import cv2
12
 
 
13
 
 
 
 
14
  IMG_SIZE = 224
15
- extractor, classifier = None, None
16
-
17
- # --- Custom Layer ---
18
- class SimpleMultiHeadAttention(layers.Layer):
19
- def __init__(self, num_heads=8, key_dim=64, **kwargs):
20
- super().__init__(**kwargs)
21
- self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
22
- def call(self, x):
23
- return self.mha(x, x)
24
-
25
- def get_custom_objects():
26
- return {"SimpleMultiHeadAttention": SimpleMultiHeadAttention, "MultiHeadAttention": layers.MultiHeadAttention}
27
 
28
- # --- Fallback extractor ---
 
 
29
  def create_fallback_extractor():
30
  base_model = tf.keras.applications.MobileNetV2(
31
  input_shape=(IMG_SIZE, IMG_SIZE, 3),
32
  include_top=False,
33
- weights='imagenet',
34
- pooling='avg'
35
  )
36
  base_model.trainable = False
37
  inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
38
  x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
39
  features = base_model(x, training=False)
40
- x = layers.Dense(512, activation='relu')(features)
41
- x = layers.Dropout(0.3)(x)
42
- x = layers.Dense(256, activation='relu')(x)
43
- outputs = layers.Dense(512, activation='relu')(x)
44
- return Model(inputs, outputs)
 
 
 
 
 
45
 
46
- # --- Safe model loading ---
47
  def load_models():
48
  global extractor, classifier
 
 
49
  try:
 
50
  extractor = load_model("hybrid_model.keras", custom_objects=get_custom_objects(), compile=False)
51
- print("✓ Hybrid extractor loaded")
52
  except Exception as e:
53
- print(f"✗ Failed to load hybrid_model.keras: {e}")
 
54
  extractor = create_fallback_extractor()
55
  print("✓ Fallback extractor created")
56
 
 
57
  try:
 
58
  classifier = joblib.load("gbdt_model.pkl")
59
- print("✓ Classifier loaded")
60
  except Exception as e:
61
- print(f"✗ Failed to load classifier: {e}")
62
  from sklearn.ensemble import AdaBoostClassifier
63
  from sklearn.tree import DecisionTreeClassifier
64
  classifier = AdaBoostClassifier(
@@ -66,21 +70,30 @@ def load_models():
66
  n_estimators=50,
67
  random_state=42
68
  )
 
69
  dummy_features = np.random.randn(10, extractor.output_shape[-1])
70
  dummy_labels = np.random.randint(0, 2, 10)
71
  classifier.fit(dummy_features, dummy_labels)
 
72
  print("✓ Dummy classifier created")
73
 
74
- # --- Image preprocessing ---
 
 
75
  def preprocess_image(img):
76
- img = np.array(img)
77
- if len(img.shape) == 2: # grayscale
 
78
  img = np.stack([img]*3, axis=-1)
 
 
79
  img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
80
  img = img.astype("float32") / 255.0
81
  return np.expand_dims(img, axis=0)
82
 
83
- # --- Prediction ---
 
 
84
  def predict(img):
85
  img_pre = preprocess_image(img)
86
  features = extractor.predict(img_pre, verbose=0).flatten().reshape(1, -1)
@@ -90,33 +103,22 @@ def predict(img):
90
  confidence = proba[pred]*100
91
  except:
92
  confidence = 85.0
93
- label = "Real" if pred==0 else "Fake"
94
- return {"label": label, "confidence": float(confidence)}
95
-
96
- # --- FastAPI ---
97
- # --- FastAPI ---
98
- from fastapi import FastAPI, UploadFile, File
99
- from fastapi.responses import JSONResponse
100
- from fastapi.middleware.cors import CORSMiddleware
101
- from PIL import Image
102
 
103
  # ======================================================
104
- # Define lifespan
105
  # ======================================================
106
- from contextlib import asynccontextmanager
107
- print("Loading models...")
108
- load_models()
109
- print("Models loaded successfully!")
110
-
111
- # Gradio interface
112
- import gradio as gr
113
 
114
- iface = gr.Interface(
115
- fn=predict,
116
- inputs=gr.Image(type="pil", label="Upload Image"),
117
- outputs=gr.Label(num_top_classes=2, label="Prediction"),
118
- title="Fake Image Detector",
119
- description="Upload an image to detect if it's Real or Fake."
120
- )
121
-
122
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
6
  from PIL import Image
7
  import tensorflow as tf
8
  from tensorflow.keras.models import load_model
 
9
  import joblib
10
+ import gradio as gr
11
  import cv2
12
 
13
+ from custom_objects import get_custom_objects # <- your custom_objects.py
14
 
15
+ # ======================================================
16
+ # CONFIG
17
+ # ======================================================
18
  IMG_SIZE = 224
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # ======================================================
21
+ # FALLBACK FEATURE EXTRACTOR
22
+ # ======================================================
23
  def create_fallback_extractor():
24
  base_model = tf.keras.applications.MobileNetV2(
25
  input_shape=(IMG_SIZE, IMG_SIZE, 3),
26
  include_top=False,
27
+ weights="imagenet",
28
+ pooling="avg"
29
  )
30
  base_model.trainable = False
31
  inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
32
  x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
33
  features = base_model(x, training=False)
34
+ x = tf.keras.layers.Dense(512, activation="relu")(features)
35
+ x = tf.keras.layers.Dropout(0.3)(x)
36
+ x = tf.keras.layers.Dense(256, activation="relu")(x)
37
+ outputs = tf.keras.layers.Dense(512, activation="relu")(x)
38
+ return tf.keras.Model(inputs, outputs)
39
+
40
+ # ======================================================
41
+ # LOAD MODELS
42
+ # ======================================================
43
+ extractor, classifier = None, None
44
 
 
45
  def load_models():
46
  global extractor, classifier
47
+
48
+ # Load feature extractor
49
  try:
50
+ print("Loading hybrid_model.keras ...")
51
  extractor = load_model("hybrid_model.keras", custom_objects=get_custom_objects(), compile=False)
52
+ print("✓ Feature extractor loaded")
53
  except Exception as e:
54
+ print(f"✗ Failed to load hybrid_model.keras ({e})")
55
+ print("Creating fallback extractor...")
56
  extractor = create_fallback_extractor()
57
  print("✓ Fallback extractor created")
58
 
59
+ # Load classifier
60
  try:
61
+ print("Loading classifier gbdt_model.pkl ...")
62
  classifier = joblib.load("gbdt_model.pkl")
63
+ print(f"✓ Classifier loaded ({type(classifier).__name__})")
64
  except Exception as e:
65
+ print(f"✗ Failed to load classifier ({e})")
66
  from sklearn.ensemble import AdaBoostClassifier
67
  from sklearn.tree import DecisionTreeClassifier
68
  classifier = AdaBoostClassifier(
 
70
  n_estimators=50,
71
  random_state=42
72
  )
73
+ # Dummy training
74
  dummy_features = np.random.randn(10, extractor.output_shape[-1])
75
  dummy_labels = np.random.randint(0, 2, 10)
76
  classifier.fit(dummy_features, dummy_labels)
77
+ joblib.dump(classifier, "classifier.pkl")
78
  print("✓ Dummy classifier created")
79
 
80
+ # ======================================================
81
+ # IMAGE PREPROCESSING
82
+ # ======================================================
83
  def preprocess_image(img):
84
+ if isinstance(img, Image.Image):
85
+ img = np.array(img)
86
+ if len(img.shape) == 2:
87
  img = np.stack([img]*3, axis=-1)
88
+ elif img.shape[2] == 3:
89
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
90
  img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
91
  img = img.astype("float32") / 255.0
92
  return np.expand_dims(img, axis=0)
93
 
94
+ # ======================================================
95
+ # PREDICTION FUNCTION
96
+ # ======================================================
97
  def predict(img):
98
  img_pre = preprocess_image(img)
99
  features = extractor.predict(img_pre, verbose=0).flatten().reshape(1, -1)
 
103
  confidence = proba[pred]*100
104
  except:
105
  confidence = 85.0
106
+ label = "Real" if pred == 0 else "Fake"
107
+ return {label: confidence}
 
 
 
 
 
 
 
108
 
109
  # ======================================================
110
+ # MAIN (Gradio)
111
  # ======================================================
112
+ if __name__ == "__main__":
113
+ print("Loading models...")
114
+ load_models()
115
+ print("Models loaded!")
 
 
 
116
 
117
+ iface = gr.Interface(
118
+ fn=predict,
119
+ inputs=gr.Image(type="pil", label="Upload Image"),
120
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
121
+ title="Fake Image Detector",
122
+ description="Upload an image to detect if it's Real or Fake."
123
+ )
124
+ iface.launch(server_name="0.0.0.0", server_port=7860)
 
custom_objects.py CHANGED
@@ -191,7 +191,7 @@ class SimpleMultiHeadAttention(layers.Layer):
191
 
192
  class FixedDropout(layers.Dropout):
193
  pass
194
-
195
 
196
  # ======================================================
197
  # RETURN ALL CUSTOM OBJECTS
 
191
 
192
  class FixedDropout(layers.Dropout):
193
  pass
194
+ # define a placeholder FixedDropout so H5 can load
195
 
196
  # ======================================================
197
  # RETURN ALL CUSTOM OBJECTS