"""Fake-image detection API.

A Keras feature extractor turns an uploaded image into a feature vector and a
scikit-learn style classifier labels that vector as "Real" or "Fake".  Models
are loaded once at application start-up (see ``lifespan``).
"""

import os

# Set before TensorFlow is imported so they take effect for this process:
# "-1" exposes no CUDA devices (CPU-only), and log level "3" silences
# TensorFlow's native logging.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras import layers, Model
import joblib
import cv2
import h5py
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager

# ======================================================
# CONFIG
# ======================================================
IMG_SIZE = 224

# Confidence (in percent) reported when the classifier cannot provide
# class probabilities.  This is a placeholder, not a measured value.
FALLBACK_CONFIDENCE = 85.0


# ======================================================
# CUSTOM LAYERS
# ======================================================
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention layer: wraps ``MultiHeadAttention`` with query == value."""

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance so get_config() can round-trip them.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        """Attend ``x`` to itself and return the attention output."""
        return self.mha(x, x)

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config


def get_custom_objects():
    """Return the name -> class mapping passed to ``load_model``."""
    return {
        'SimpleMultiHeadAttention': SimpleMultiHeadAttention,
        'MultiHeadAttention': layers.MultiHeadAttention,
        'Dropout': layers.Dropout
    }


# ======================================================
# FIX MISSING 'predictions' GROUP IN H5 FILE
# ======================================================
def fix_missing_predictions(h5_path):
    """Ensure ``model_weights/predictions`` exists in an HDF5 model file.

    If the group is absent it is created with an empty ``weight_names``
    attribute.  The file is edited in place.  Paths that do not exist or are
    not HDF5 files are left untouched.  Failures are reported on stdout and
    never raised, because this is a best-effort repair step.

    Args:
        h5_path: Path to the model file to inspect.
    """
    # Never open a non-HDF5 file in "r+" mode; is_hdf5 is also False for a
    # path that does not exist.
    if not h5py.is_hdf5(h5_path):
        print(f"⚠️ '{h5_path}' is not an HDF5 file — skipping H5 fix.")
        return

    try:
        with h5py.File(h5_path, "r+") as f:
            if "model_weights" not in f:
                print("⚠️ H5 file has no 'model_weights' group — cannot fix this model.")
                return

            pred_path = "model_weights/predictions"
            if pred_path in f:
                return

            grp = f.require_group(pred_path)
            if "weight_names" not in grp.attrs:
                grp.attrs.create("weight_names", [])
    except Exception as e:
        print("❌ Failed to edit H5:", e)


# ======================================================
# FALLBACK FEATURE EXTRACTOR
# ======================================================
def create_fallback_extractor():
    """Build a stand-in extractor used when the saved model cannot be loaded.

    A frozen ImageNet MobileNetV2 backbone (global-average pooled) is followed
    by Dense layers ending in a 512-wide output.  The Dense layers are freshly
    initialised and untrained here.

    Returns:
        A Keras ``Model`` taking ``(IMG_SIZE, IMG_SIZE, 3)`` images.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg'
    )
    base_model.trainable = False

    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    # preprocess_image() yields pixels in [0, 1], while MobileNetV2's
    # preprocess_input expects [0, 255]; undo the scaling first.
    x = layers.Rescaling(255.0)(inputs)
    x = tf.keras.applications.mobilenet_v2.preprocess_input(x)
    features = base_model(x, training=False)
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    outputs = layers.Dense(512, activation='relu')(x)
    return Model(inputs, outputs)


# ======================================================
# LOAD MODELS
# ======================================================
extractor, classifier = None, None


def load_models():
    """Populate the module-level ``extractor`` and ``classifier``.

    Each model is loaded from disk; on any failure a fallback is built so the
    service can still start.  NOTE: the fallback classifier is fitted on
    random data, so its predictions carry no meaning.
    """
    global extractor, classifier

    # Load feature extractor
    try:
        fix_missing_predictions("hybrid_model.keras")
        extractor = load_model(
            "hybrid_model.keras",
            custom_objects=get_custom_objects(),
            compile=False,
        )
        print("✔ Feature extractor loaded")
    except Exception as e:
        print(f"⚠ Failed to load extractor: {e}")
        extractor = create_fallback_extractor()
        print("✔ Fallback extractor created")

    # Load classifier
    # NOTE(review): joblib.load unpickles the file; only load trusted files.
    try:
        classifier = joblib.load("gbdt_model.pkl")
        print("✔ Classifier loaded")
    except Exception as e:
        print(f"⚠ Failed to load classifier: {e}")
        from sklearn.ensemble import AdaBoostClassifier
        from sklearn.tree import DecisionTreeClassifier

        classifier = AdaBoostClassifier(
            estimator=DecisionTreeClassifier(max_depth=3),
            n_estimators=50,
            random_state=40
        )
        # Seeded features and alternating labels: the fit is reproducible and
        # both classes are always present (random labels could all coincide).
        rng = np.random.default_rng(40)
        dummy_features = rng.standard_normal((10, extractor.output_shape[-1]))
        dummy_labels = np.arange(10) % 2
        classifier.fit(dummy_features, dummy_labels)
        print("✔ Dummy classifier created")


# ======================================================
# IMAGE PREPROCESSING
# ======================================================
def preprocess_image(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a model-ready batch.

    Args:
        img: Input image in any PIL mode.

    Returns:
        ``float32`` array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)`` with values
        scaled to ``[0, 1]``.
    """
    # Normalise to three channels up front so grayscale, palette and alpha
    # images all end up with the same layout.
    pixels = np.array(img.convert("RGB"))
    pixels = cv2.resize(pixels, (IMG_SIZE, IMG_SIZE))
    pixels = pixels.astype("float32") / 255.0
    return np.expand_dims(pixels, axis=0)


# ======================================================
# PREDICTION
# ======================================================
def predict_image(img: Image.Image) -> dict:
    """Classify one image.

    Args:
        img: Image to classify.

    Returns:
        ``{"label": "Real" | "Fake", "confidence": float}`` where confidence
        is a percentage.  If class probabilities are unavailable, confidence
        is ``FALLBACK_CONFIDENCE``.

    Raises:
        RuntimeError: If ``load_models()`` has not populated the models.
    """
    if extractor is None or classifier is None:
        raise RuntimeError("Models are not loaded; call load_models() first.")

    img_pre = preprocess_image(img)
    features = extractor.predict(img_pre, verbose=0).reshape(1, -1)
    pred = classifier.predict(features)[0]

    try:
        proba = classifier.predict_proba(features)[0]
        # predict_proba columns follow classifier.classes_, which need not be
        # 0..n-1, so map the predicted label to its column.
        classes = getattr(classifier, "classes_", None)
        if classes is None:
            column = int(pred)
        else:
            column = int(np.flatnonzero(np.asarray(classes) == pred)[0])
        confidence = proba[column] * 100
    except (AttributeError, NotImplementedError, IndexError, ValueError, TypeError) as e:
        print(f"⚠ Could not compute confidence, using fallback: {e}")
        confidence = FALLBACK_CONFIDENCE

    label = "Real" if pred == 0 else "Fake"
    return {"label": label, "confidence": float(confidence)}


# ======================================================
# LIFESPAN + FASTAPI APP
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models before the app starts serving; log on shutdown."""
    print("⚡ Starting app and loading models...")
    load_models()
    yield
    print("⚡ Shutting down app...")


app = FastAPI(title="Fake Image Detector API", lifespan=lifespan)

# CORS
# NOTE(review): every origin, method and header is allowed; tighten this if
# the API is not meant to be called from arbitrary sites.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)


# ROUTES
@app.get("/")
def root():
    """Health-check endpoint."""
    return {"message": "API is running!"}


@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns the ``predict_image`` result as JSON.  An upload that cannot be
    decoded yields HTTP 400; a failure during inference yields HTTP 500.
    Both error bodies have the form ``{"error": <message>}``.
    """
    # NOTE: inference below is synchronous, so this coroutine occupies the
    # event loop for the duration of a prediction.
    try:
        img = Image.open(file.file).convert("RGB")
    except Exception as e:
        # Any decoding failure is attributed to the uploaded data.
        return JSONResponse({"error": str(e)}, status_code=400)

    try:
        return JSONResponse(predict_image(img))
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)