File size: 5,897 Bytes
e880e5e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import os
# Run TensorFlow on CPU only (no CUDA devices visible) and keep its C++ logging quiet.
# These variables are set before tensorflow is imported so they are in effect
# when TensorFlow initialises.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "-1",
    "TF_CPP_MIN_LOG_LEVEL": "3",
})
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras import layers, Model
import joblib
import cv2
import h5py
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
# ======================================================
# CONFIG
# ======================================================
# Edge length, in pixels, of the square RGB input used throughout the pipeline
# (resize target for uploaded images and input shape of the fallback extractor).
IMG_SIZE: int = 224
# ======================================================
# CUSTOM LAYERS
# ======================================================
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention layer wrapping Keras ``MultiHeadAttention``.

    The same input tensor is passed as both query and value.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        **kwargs: Standard ``Layer`` keyword arguments (name, dtype, ...).
    """

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can round-trip them;
        # without this a reloaded layer silently falls back to the defaults.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        # Self-attention: x is used as both query and value.
        return self.mha(x, x)

    def get_config(self):
        """Return the serializable config, including ``num_heads`` and ``key_dim``."""
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
def get_custom_objects():
    """Return the name -> class mapping passed to ``load_model(custom_objects=...)``.

    Includes the project's ``SimpleMultiHeadAttention`` plus the Keras
    ``MultiHeadAttention`` and ``Dropout`` layers.
    """
    registry = dict(
        SimpleMultiHeadAttention=SimpleMultiHeadAttention,
        MultiHeadAttention=layers.MultiHeadAttention,
        Dropout=layers.Dropout,
    )
    return registry
# ======================================================
# FIX MISSING 'predictions' GROUP IN H5 FILE
# ======================================================
def fix_missing_predictions(h5_path):
    """Best-effort patch: add an empty ``model_weights/predictions`` group to an HDF5 model file.

    If ``h5_path`` is an HDF5 file that has a ``model_weights`` group but no
    ``model_weights/predictions`` subgroup, the subgroup is created with an
    empty ``weight_names`` attribute. The file is modified IN PLACE (opened
    in ``"r+"`` mode).

    Paths that do not exist or are not HDF5 files are skipped without a
    message, because there is nothing to patch. NOTE(review): the caller
    passes a ``.keras`` path; confirm whether that file is actually HDF5.
    Any other error is printed and swallowed so model loading can still be
    attempted.

    Args:
        h5_path: Path to the model file.
    """
    try:
        # Opening a non-HDF5 (or missing) file in "r+" mode always fails;
        # checking first avoids reporting a spurious edit failure.
        if not h5py.is_hdf5(h5_path):
            return
        with h5py.File(h5_path, "r+") as f:
            if "model_weights" not in f:
                print("⚠️ H5 file has no 'model_weights' group — cannot fix this model.")
                return
            pred_path = "model_weights/predictions"
            if pred_path in f:
                return  # already present; nothing to do
            grp = f.require_group(pred_path)
            if "weight_names" not in grp.attrs:
                grp.attrs.create("weight_names", [])
    except Exception as e:
        # Deliberately best-effort: report and let the caller try loading anyway.
        print("❌ Failed to edit H5:", e)
# ======================================================
# FALLBACK FEATURE EXTRACTOR
# ======================================================
def create_fallback_extractor():
    """Build a stand-in feature extractor for when the trained model cannot be loaded.

    A frozen MobileNetV2 backbone (``include_top=False``, average pooling,
    ``weights='imagenet'``) feeds a small dense head. The head's weights are
    freshly initialised, not trained. The model outputs a 512-dim vector.
    NOTE(review): ``weights='imagenet'`` may require network access on first use.

    Inputs are expected to have shape ``(batch, IMG_SIZE, IMG_SIZE, 3)`` with
    pixel values already scaled to [0, 1], matching this module's
    ``preprocess_image`` (which divides by 255).

    Returns:
        A ``tf.keras.Model`` mapping images to 512-dim features.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    base_model.trainable = False

    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    # mobilenet_v2.preprocess_input expects pixels in [0, 255]. Our inputs are
    # already in [0, 1], so undo that scaling first; otherwise every image
    # collapses to values near -1.
    x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs * 255.0)
    features = base_model(x, training=False)
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    outputs = layers.Dense(512, activation='relu')(x)
    return Model(inputs, outputs)
# ======================================================
# LOAD MODELS
# ======================================================
# Module-level model handles, filled in by load_models() (called from the app
# lifespan hook). Both stay None until then.
extractor = None
classifier = None
def _load_extractor():
    """Load the trained feature extractor, or build the MobileNetV2 fallback on failure."""
    try:
        fix_missing_predictions("hybrid_model.keras")
        model = load_model("hybrid_model.keras", custom_objects=get_custom_objects(), compile=False)
        print("✔ Feature extractor loaded")
        return model
    except Exception as e:
        print(f"⚠ Failed to load extractor: {e}")
    # Not guarded: if the fallback cannot be built either, the exception propagates.
    model = create_fallback_extractor()
    print("✔ Fallback extractor created")
    return model


def _load_classifier(feature_extractor):
    """Load the trained classifier, or fit a placeholder on synthetic data on failure.

    Args:
        feature_extractor: Model whose ``output_shape[-1]`` sets the placeholder's
            feature count (only read when the fallback is used).
    """
    try:
        # joblib.load unpickles the file: only load model files from a trusted source.
        model = joblib.load("gbdt_model.pkl")
        print("✔ Classifier loaded")
        return model
    except Exception as e:
        print(f"⚠ Failed to load classifier: {e}")
    from sklearn.ensemble import AdaBoostClassifier
    from sklearn.tree import DecisionTreeClassifier
    # WARNING: this placeholder is fitted on random features with synthetic labels;
    # its predictions carry no information about the input image.
    model = AdaBoostClassifier(
        estimator=DecisionTreeClassifier(max_depth=3),
        n_estimators=50,
        random_state=40,
    )
    n_features = feature_extractor.output_shape[-1]
    # Seed the synthetic data as well, so the placeholder is reproducible across
    # restarts, and alternate labels so both classes 0 and 1 are always present.
    rng = np.random.default_rng(40)
    dummy_features = rng.standard_normal((10, n_features))
    dummy_labels = np.arange(10) % 2
    model.fit(dummy_features, dummy_labels)
    print("✔ Dummy classifier created")
    return model


def load_models():
    """Populate the module-level ``extractor`` and ``classifier`` globals.

    Each model is loaded from disk (``hybrid_model.keras`` / ``gbdt_model.pkl``);
    on failure a fallback is used and a warning is printed. The extractor is
    assigned before the classifier is attempted.
    """
    global extractor, classifier
    extractor = _load_extractor()
    classifier = _load_classifier(extractor)
# ======================================================
# IMAGE PREPROCESSING
# ======================================================
def preprocess_image(img: Image.Image):
    """Convert a PIL image into a single-image batch for the feature extractor.

    The image is converted to 3-channel RGB and resized to
    ``IMG_SIZE x IMG_SIZE`` with ``cv2.resize`` (default interpolation). Pixel
    values are then scaled from uint8 to float32 in [0, 1].

    Args:
        img: Input image in any PIL mode.

    Returns:
        float32 array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``.
    """
    # Normalising the mode first guarantees exactly three uint8 channels for
    # every PIL mode (L, LA, P, 1, RGBA, CMYK, ...). Checking only for 2-D arrays
    # let 4-channel, 2-channel, palette-index and boolean arrays through.
    rgb = np.asarray(img.convert("RGB"))
    resized = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE))
    scaled = resized.astype("float32") / 255.0
    return np.expand_dims(scaled, axis=0)
# ======================================================
# PREDICTION
# ======================================================
def predict_image(img: Image.Image):
    """Classify an image as "Real" or "Fake".

    Args:
        img: PIL image; preprocessed with ``preprocess_image``.

    Returns:
        ``{"label": str, "confidence": float}``. The label is ``"Real"`` when
        the classifier predicts 0 and ``"Fake"`` otherwise. The confidence is
        the predicted class's probability in percent, or the fixed placeholder
        85.0 when the classifier cannot provide probabilities.

    Raises:
        RuntimeError: if the models have not been loaded yet.
    """
    if extractor is None or classifier is None:
        raise RuntimeError("Models are not loaded; load_models() must run first")
    batch = preprocess_image(img)
    # One feature row per image, whatever shape the extractor outputs.
    features = extractor.predict(batch, verbose=0).reshape(1, -1)
    pred = classifier.predict(features)[0]
    try:
        proba = classifier.predict_proba(features)[0]
        # predict_proba columns follow classes_, so look the label's position up
        # rather than using the label value itself as an index.
        classes = getattr(classifier, "classes_", None)
        idx = int(pred) if classes is None else list(classes).index(pred)
        confidence = proba[idx] * 100
    except Exception:
        # Best effort (not a bare except, so Ctrl-C/SystemExit still propagate).
        confidence = 85.0
    label = "Real" if pred == 0 else "Fake"
    return {"label": label, "confidence": float(confidence)}
# ======================================================
# LIFESPAN + FASTAPI APP
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Loads both models before the app starts serving, then yields control to
    FastAPI. The shutdown message after ``yield`` is printed when the context
    exits normally.
    """
    startup_banner = "⚡ Starting app and loading models..."
    shutdown_banner = "⚡ Shutting down app..."
    print(startup_banner)
    load_models()  # synchronous; finishes before the app begins serving
    yield
    print(shutdown_banner)
app = FastAPI(
    lifespan=lifespan,
    title="Fake Image Detector API",
)

# CORS: any origin may call the API, with any HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ROUTES
@app.get("/")
def root():
    """Liveness check: confirm the API process is responding."""
    payload = {"message": "API is running!"}
    return payload
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image as Real or Fake.

    Returns the ``predict_image`` result as JSON. If the upload cannot be
    decoded as an image, the response is ``{"error": ...}`` with status 400.
    If running the models fails, the same shape is returned with status 500.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    def _decode():
        return Image.open(file.file).convert("RGB")

    # Decoding and inference are blocking, CPU-bound calls. Running them in the
    # threadpool keeps the event loop free to serve other requests.
    try:
        img = await run_in_threadpool(_decode)
    except Exception as e:
        # Unreadable, truncated, oversized or unsupported upload: a client error.
        return JSONResponse({"error": str(e)}, status_code=400)
    try:
        result = await run_in_threadpool(predict_image, img)
    except Exception as e:
        # Model/inference failure is a server-side problem, not a bad request.
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse(result)
|