eesfeg committed on
Commit
e880e5e
·
1 Parent(s): 79cd36b

hoooollll

Browse files
Files changed (2) hide show
  1. app.py +65 -88
  2. app_o.py +170 -0
app.py CHANGED
@@ -1,63 +1,37 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
-
5
  import numpy as np
6
  from PIL import Image
7
  import tensorflow as tf
8
  from tensorflow.keras.models import load_model
9
  from tensorflow.keras import layers, Model
10
  import joblib
 
11
  import cv2
12
- import h5py
13
- from fastapi import FastAPI, UploadFile, File
14
- from fastapi.responses import JSONResponse
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from contextlib import asynccontextmanager
17
-
18
  # ======================================================
19
  # CONFIG
20
  # ======================================================
21
  IMG_SIZE = 224
22
 
23
  # ======================================================
24
- # CUSTOM LAYERS
25
  # ======================================================
26
  class SimpleMultiHeadAttention(layers.Layer):
27
  def __init__(self, num_heads=8, key_dim=64, **kwargs):
28
  super().__init__(**kwargs)
29
  self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
30
-
31
  def call(self, x):
32
  return self.mha(x, x)
33
 
34
- def get_custom_objects():
35
- return {
36
- 'SimpleMultiHeadAttention': SimpleMultiHeadAttention,
37
- 'MultiHeadAttention': layers.MultiHeadAttention,
38
- 'Dropout': layers.Dropout
39
- }
40
 
41
  # ======================================================
42
- # FIX MISSING 'predictions' GROUP IN H5 FILE
43
- # ======================================================
44
- def fix_missing_predictions(h5_path):
45
- try:
46
- with h5py.File(h5_path, "r+") as f:
47
- if "model_weights" not in f:
48
- print("⚠️ H5 file has no 'model_weights' group — cannot fix this model.")
49
- return
50
- pred_path = "model_weights/predictions"
51
- if pred_path in f:
52
- return
53
- grp = f.require_group(pred_path)
54
- if "weight_names" not in grp.attrs:
55
- grp.attrs.create("weight_names", [])
56
- except Exception as e:
57
- print("❌ Failed to edit H5:", e)
58
-
59
- # ======================================================
60
- # FALLBACK FEATURE EXTRACTOR
61
  # ======================================================
62
  def create_fallback_extractor():
63
  base_model = tf.keras.applications.MobileNetV2(
@@ -67,6 +41,7 @@ def create_fallback_extractor():
67
  pooling='avg'
68
  )
69
  base_model.trainable = False
 
70
  inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
71
  x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
72
  features = base_model(x, training=False)
@@ -74,97 +49,99 @@ def create_fallback_extractor():
74
  x = layers.Dropout(0.3)(x)
75
  x = layers.Dense(256, activation='relu')(x)
76
  outputs = layers.Dense(512, activation='relu')(x)
77
- return Model(inputs, outputs)
 
 
78
 
79
  # ======================================================
80
  # LOAD MODELS
81
  # ======================================================
82
- extractor, classifier = None, None
83
-
84
  def load_models():
85
  global extractor, classifier
86
- # Load feature extractor
 
87
  try:
88
- fix_missing_predictions("hybrid_model.keras")
89
- extractor = load_model("hybrid_model.keras", custom_objects=get_custom_objects(), compile=False)
90
- print("✔ Feature extractor loaded")
 
 
 
 
 
91
  except Exception as e:
92
- print(f" Failed to load extractor: {e}")
 
93
  extractor = create_fallback_extractor()
94
- print(" Fallback extractor created")
 
95
  # Load classifier
96
  try:
 
97
  classifier = joblib.load("gbdt_model.pkl")
98
- print(" Classifier loaded")
99
  except Exception as e:
100
- print(f" Failed to load classifier: {e}")
101
  from sklearn.ensemble import AdaBoostClassifier
102
  from sklearn.tree import DecisionTreeClassifier
103
  classifier = AdaBoostClassifier(
104
  estimator=DecisionTreeClassifier(max_depth=3),
105
  n_estimators=50,
106
- random_state=40
107
  )
108
  dummy_features = np.random.randn(10, extractor.output_shape[-1])
109
  dummy_labels = np.random.randint(0, 2, 10)
110
  classifier.fit(dummy_features, dummy_labels)
111
- print("✔ Dummy classifier created")
 
112
 
113
  # ======================================================
114
  # IMAGE PREPROCESSING
115
  # ======================================================
116
- def preprocess_image(img: Image.Image):
117
- img = np.array(img)
 
 
 
118
  img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
119
  img = img.astype("float32") / 255.0
120
- if len(img.shape) == 2:
121
- img = np.stack([img]*3, axis=-1)
122
  return np.expand_dims(img, axis=0)
123
 
124
  # ======================================================
125
- # PREDICTION
126
  # ======================================================
127
- def predict_image(img: Image.Image):
128
- img_pre = preprocess_image(img)
129
- features = extractor.predict(img_pre, verbose=0).flatten().reshape(1, -1)
 
 
 
 
 
130
  pred = classifier.predict(features)[0]
131
  try:
132
  proba = classifier.predict_proba(features)[0]
133
  confidence = proba[pred] * 100
134
  except:
135
- confidence = 85.0
 
136
  label = "Real" if pred == 0 else "Fake"
137
- return {"label": label, "confidence": float(confidence)}
138
 
139
  # ======================================================
140
- # LIFESPAN + FASTAPI APP
141
  # ======================================================
142
- @asynccontextmanager
143
- async def lifespan(app: FastAPI):
144
- print("⚡ Starting app and loading models...")
145
  load_models()
146
- yield
147
- print("⚡ Shutting down app...")
148
-
149
- app = FastAPI(title="Fake Image Detector API", lifespan=lifespan)
150
-
151
- # CORS
152
- app.add_middleware(
153
- CORSMiddleware,
154
- allow_origins=["*"],
155
- allow_methods=["*"],
156
- allow_headers=["*"]
157
- )
158
-
159
- # ROUTES
160
- @app.get("/")
161
- def root():
162
- return {"message": "API is running!"}
163
-
164
- @app.post("/predict/")
165
- async def predict_endpoint(file: UploadFile = File(...)):
166
- try:
167
- img = Image.open(file.file).convert("RGB")
168
- return JSONResponse(predict_image(img))
169
- except Exception as e:
170
- return JSONResponse({"error": str(e)}, status_code=400)
 
1
  import os
 
 
 
2
  import numpy as np
3
  from PIL import Image
4
  import tensorflow as tf
5
  from tensorflow.keras.models import load_model
6
  from tensorflow.keras import layers, Model
7
  import joblib
8
+ import gradio as gr
9
  import cv2
10
+ from custom_objects import get_custom_objects
 
 
 
 
 
11
  # ======================================================
12
  # CONFIG
13
  # ======================================================
14
  IMG_SIZE = 224
15
 
16
# ======================================================
# CUSTOM LAYERS (for H5 loading)
# ======================================================
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention wrapper around Keras MultiHeadAttention.

    Query, key and value are all the layer input, i.e. plain
    self-attention. Kept here so saved models referencing the layer
    by name can be deserialized.
    """

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        # Store hyperparameters so get_config() can round-trip them.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        return self.mha(x, x)

    def get_config(self):
        # Required for serialization: without this, reloading a saved model
        # silently rebuilds the layer with default num_heads/key_dim.
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
 
32
 
33
# ======================================================
# FEATURE EXTRACTOR CREATION (fallback)
# ======================================================
def create_fallback_extractor():
    """Build a frozen MobileNetV2-based extractor emitting 512-d features."""
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    backbone.trainable = False  # keep the ImageNet weights fixed

    image_in = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    h = tf.keras.applications.mobilenet_v2.preprocess_input(image_in)
    h = backbone(h, training=False)
    # Small projection head on top of the pooled backbone features.
    h = layers.Dense(512, activation='relu')(h)
    h = layers.Dropout(0.3)(h)
    h = layers.Dense(256, activation='relu')(h)
    feature_out = layers.Dense(512, activation='relu')(h)
    return Model(image_in, feature_out)
55
 
56
# ======================================================
# LOAD MODELS
# ======================================================
def load_models():
    """Populate the global extractor/classifier, with graceful fallbacks.

    Tries the trained extractor in ``moodeli.h5``; on any failure, builds
    the frozen MobileNetV2 fallback. Likewise tries ``gbdt_model.pkl`` for
    the classifier and otherwise fits a dummy AdaBoost model on random
    data so the demo stays usable.
    """
    global extractor, classifier

    # Load extractor
    try:
        print("Loading feature extractor (moodeli.h5)...")
        # NOTE: a debug h5py.File(...) inspection used to live here, but this
        # module no longer imports h5py — it raised NameError on every call
        # and unconditionally forced the fallback extractor. Removed.
        extractor = load_model("moodeli.h5", custom_objects=get_custom_objects(), compile=False)
        print("✓ Feature extractor loaded successfully")
    except Exception as e:
        print(f" Failed to load H5 extractor: {str(e)[:200]}")
        print("Creating fallback extractor...")
        extractor = create_fallback_extractor()
        print(" Fallback extractor created")

    # Load classifier
    try:
        print("Loading classifier (gbdt_model.pkl)...")
        classifier = joblib.load("gbdt_model.pkl")
        print(f" Classifier loaded ({type(classifier).__name__})")
    except Exception as e:
        print(f" Failed to load classifier: {str(e)[:200]}")
        from sklearn.ensemble import AdaBoostClassifier
        from sklearn.tree import DecisionTreeClassifier
        classifier = AdaBoostClassifier(
            estimator=DecisionTreeClassifier(max_depth=3),
            n_estimators=50,
            random_state=42,
        )
        # Fit on random features matching the extractor's output width so
        # predict()/predict_proba() work even without a real model file.
        dummy_features = np.random.randn(10, extractor.output_shape[-1])
        dummy_labels = np.random.randint(0, 2, 10)
        classifier.fit(dummy_features, dummy_labels)
        joblib.dump(classifier, "classifier.pkl")
        print("✓ Dummy classifier created and saved")
97
 
98
# ======================================================
# IMAGE PREPROCESSING
# ======================================================
def preprocess_image(img):
    """Convert a PIL image or ndarray into a (1, IMG_SIZE, IMG_SIZE, 3) batch.

    Pixels are scaled to [0, 1]; RGB input is converted to BGR to match
    the original cv2-based training pipeline.
    """
    if isinstance(img, Image.Image):
        img = np.array(img)
    if img.ndim == 2:
        # Grayscale: replicate the channel so the extractor's 3-channel input
        # contract holds (the previous FastAPI version handled this; the
        # Gradio rewrite dropped it and would emit a (1, H, W) batch).
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        # RGBA: drop the alpha channel before the channel-order conversion.
        img = img[:, :, :3]
    if img.ndim == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    img = img.astype("float32") / 255.0
    return np.expand_dims(img, axis=0)
109
 
110
# ======================================================
# PREDICTION FUNCTION
# ======================================================
def predict(image):
    """Classify an uploaded image as Real or Fake.

    Returns a dict mapping label -> confidence, which is the shape
    gr.Label renders.
    """
    if image is None:
        # Keep the return type consistent with the success path: a dict,
        # not a list of tuples (gr.Label cannot render the latter).
        return {"No Image": 0.0}

    img_pre = preprocess_image(image)
    features = extractor.predict(img_pre, verbose=0).flatten()
    features = features.reshape(1, -1)

    pred = classifier.predict(features)[0]
    try:
        proba = classifier.predict_proba(features)[0]
        # predict_proba columns follow classifier.classes_; look the class
        # up by position instead of using its label value as an index.
        class_idx = int(np.where(classifier.classes_ == pred)[0][0])
        confidence = proba[class_idx] * 100
    except Exception:
        # Some classifiers lack predict_proba; fall back to a fixed score.
        confidence = 85.0  # default

    label = "Real" if pred == 0 else "Fake"
    return {label: float(confidence)}
130
 
131
# ======================================================
# MAIN (Hugging Face Spaces)
# ======================================================
if __name__ == "__main__":
    # Load models eagerly at startup so the first request is fast.
    print("Loading models...")
    load_models()
    print("Models loaded successfully!")

    # Single-function Gradio UI: image in, label + confidence out.
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil", label="📷 Upload Image"),
        outputs=gr.Label(num_top_classes=2, label="🎯 Prediction"),
        title="🔍 Fake Image Detector",
        description="Upload an image to detect if it's Real or Fake.",
    )
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_o.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ import tensorflow as tf
8
+ from tensorflow.keras.models import load_model
9
+ from tensorflow.keras import layers, Model
10
+ import joblib
11
+ import cv2
12
+ import h5py
13
+ from fastapi import FastAPI, UploadFile, File
14
+ from fastapi.responses import JSONResponse
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from contextlib import asynccontextmanager
17
+
18
+ # ======================================================
19
+ # CONFIG
20
+ # ======================================================
21
+ IMG_SIZE = 224
22
+
23
# ======================================================
# CUSTOM LAYERS
# ======================================================
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention block: MultiHeadAttention applied with q = k = v = x."""

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        # Store hyperparameters so get_config() can round-trip them.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        return self.mha(x, x)

    def get_config(self):
        # Needed so save/load round-trips preserve non-default hyperparameters
        # instead of silently rebuilding with num_heads=8, key_dim=64.
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
33
+
34
def get_custom_objects():
    """Map serialized layer names to classes for loading the hybrid model."""
    return dict(
        SimpleMultiHeadAttention=SimpleMultiHeadAttention,
        MultiHeadAttention=layers.MultiHeadAttention,
        Dropout=layers.Dropout,
    )
40
+
41
# ======================================================
# FIX MISSING 'predictions' GROUP IN H5 FILE
# ======================================================
def fix_missing_predictions(h5_path):
    """Best-effort repair: ensure the H5 file has a 'predictions' group.

    Some exports lack 'model_weights/predictions', which can break loading;
    this creates the group with an empty weight-name list. Errors are
    printed but never raised — the caller falls back to another extractor.
    """
    group_path = "model_weights/predictions"
    try:
        with h5py.File(h5_path, "r+") as h5:
            if "model_weights" not in h5:
                print("⚠️ H5 file has no 'model_weights' group — cannot fix this model.")
            elif group_path not in h5:
                predictions = h5.require_group(group_path)
                if "weight_names" not in predictions.attrs:
                    predictions.attrs.create("weight_names", [])
    except Exception as e:
        print("❌ Failed to edit H5:", e)
58
+
59
# ======================================================
# FALLBACK FEATURE EXTRACTOR
# ======================================================
def create_fallback_extractor():
    """Return a frozen MobileNetV2 model that emits 512-d image features."""
    cnn = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    cnn.trainable = False  # ImageNet weights stay frozen

    img_input = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    net = tf.keras.applications.mobilenet_v2.preprocess_input(img_input)
    net = cnn(net, training=False)
    # Projection head: 512 -> dropout -> 256 -> 512 feature vector.
    net = layers.Dense(512, activation='relu')(net)
    net = layers.Dropout(0.3)(net)
    net = layers.Dense(256, activation='relu')(net)
    net = layers.Dense(512, activation='relu')(net)
    return Model(img_input, net)
78
+
79
# ======================================================
# LOAD MODELS
# ======================================================
extractor, classifier = None, None  # populated by load_models() at startup


def load_models():
    """Populate the global extractor/classifier, falling back when files fail."""
    global extractor, classifier

    # Feature extractor: prefer the trained hybrid model on disk.
    try:
        fix_missing_predictions("hybrid_model.keras")
        extractor = load_model("hybrid_model.keras",
                               custom_objects=get_custom_objects(),
                               compile=False)
        print("✔ Feature extractor loaded")
    except Exception as e:
        print(f"⚠ Failed to load extractor: {e}")
        extractor = create_fallback_extractor()
        print("✔ Fallback extractor created")

    # Classifier: prefer the pickled GBDT; otherwise train a stand-in.
    try:
        classifier = joblib.load("gbdt_model.pkl")
        print("✔ Classifier loaded")
    except Exception as e:
        print(f"⚠ Failed to load classifier: {e}")
        from sklearn.ensemble import AdaBoostClassifier
        from sklearn.tree import DecisionTreeClassifier
        classifier = AdaBoostClassifier(
            estimator=DecisionTreeClassifier(max_depth=3),
            n_estimators=50,
            random_state=40,
        )
        # Random data sized to the extractor's output keeps predict() usable.
        dummy_features = np.random.randn(10, extractor.output_shape[-1])
        dummy_labels = np.random.randint(0, 2, 10)
        classifier.fit(dummy_features, dummy_labels)
        print("✔ Dummy classifier created")
112
+
113
# ======================================================
# IMAGE PREPROCESSING
# ======================================================
def preprocess_image(img: Image.Image):
    """Convert an image into a normalized (1, IMG_SIZE, IMG_SIZE, 3) batch.

    Handles grayscale (channel-replicated) and RGBA (alpha dropped) input
    so the 3-channel extractor contract always holds, even if a caller
    skips the .convert("RGB") step the endpoint performs.
    """
    img = np.array(img)
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]  # RGBA: drop the alpha channel
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    img = img.astype("float32") / 255.0
    if img.ndim == 2:
        # Grayscale: replicate to three identical channels.
        img = np.stack([img] * 3, axis=-1)
    return np.expand_dims(img, axis=0)
123
+
124
# ======================================================
# PREDICTION
# ======================================================
def predict_image(img: Image.Image):
    """Run extractor + classifier on one image.

    Returns {"label": "Real"|"Fake", "confidence": percent} for the JSON API.
    """
    img_pre = preprocess_image(img)
    features = extractor.predict(img_pre, verbose=0).flatten().reshape(1, -1)
    pred = classifier.predict(features)[0]
    try:
        proba = classifier.predict_proba(features)[0]
        # predict_proba columns follow classifier.classes_; look the class up
        # by position instead of using its label value as an index.
        class_idx = int(np.where(classifier.classes_ == pred)[0][0])
        confidence = proba[class_idx] * 100
    except Exception:
        # Classifier without predict_proba/classes_: fixed fallback score.
        confidence = 85.0
    label = "Real" if pred == 0 else "Fake"
    return {"label": label, "confidence": float(confidence)}
138
+
139
# ======================================================
# LIFESPAN + FASTAPI APP
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load (or build fallback) models once before serving requests.
    print("⚡ Starting app and loading models...")
    load_models()
    yield
    # Shutdown: nothing to release — models are plain in-memory objects.
    print("⚡ Shutting down app...")

app = FastAPI(title="Fake Image Detector API", lifespan=lifespan)

# CORS
# NOTE(review): wildcard origins suit a public demo Space; tighten
# allow_origins before fronting authenticated clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)
158
+
159
# ROUTES
@app.get("/")
def root():
    """Health-check endpoint."""
    return {"message": "API is running!"}


@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image; any failure yields a 400 with the error."""
    try:
        image = Image.open(file.file).convert("RGB")
        result = predict_image(image)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=400)
    return JSONResponse(result)