# Residue from the Hugging Face file-viewer page, preserved as a comment so the
# module parses as valid Python:
#   repo "Fake", file app_o.py, uploader "eesfeg",
#   commit message "hoooollll", revision e880e5e
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras import layers, Model
import joblib
import cv2
import h5py
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
# ======================================================
# CONFIG
# ======================================================
IMG_SIZE = 224
# ======================================================
# CUSTOM LAYERS
# ======================================================
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention wrapper: MultiHeadAttention applied with query == key == value.

    Exists so models saved with this custom class can be reloaded via
    ``custom_objects`` (see ``get_custom_objects``).
    """

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        # Keep constructor args so get_config() can serialize them; the original
        # relied on the saved model matching these defaults at reload time.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

    def call(self, x):
        # Self-attention: the sequence attends to itself.
        return self.mha(x, x)

    def get_config(self):
        """Persist constructor arguments so save/load round-trips exactly."""
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
def get_custom_objects():
    """Name -> class mapping needed to deserialize the saved hybrid model."""
    mapping = {}
    mapping['SimpleMultiHeadAttention'] = SimpleMultiHeadAttention
    mapping['MultiHeadAttention'] = layers.MultiHeadAttention
    mapping['Dropout'] = layers.Dropout
    return mapping
# ======================================================
# FIX MISSING 'predictions' GROUP IN H5 FILE
# ======================================================
def fix_missing_predictions(h5_path):
    """Best-effort patch for legacy Keras H5 files that lack the
    'model_weights/predictions' group (which makes load_model fail).

    Silently degrades on any problem (missing file, non-HDF5 format such as the
    zip-based .keras archive, permissions) so startup can fall back to the
    alternate extractor.
    """
    # Guard: opening a nonexistent file in "r+" would raise; skip quietly instead.
    if not os.path.exists(h5_path):
        return
    try:
        with h5py.File(h5_path, "r+") as f:
            if "model_weights" not in f:
                print("⚠️ H5 file has no 'model_weights' group β€” cannot fix this model.")
                return
            pred_path = "model_weights/predictions"
            if pred_path in f:
                return  # group already present — nothing to patch
            grp = f.require_group(pred_path)
            if "weight_names" not in grp.attrs:
                # Use an explicit empty byte-string array: Keras stores layer
                # weight names as byte strings, and an explicit dtype avoids
                # h5py inferring float64 from a plain empty list.
                grp.attrs.create("weight_names", np.array([], dtype="S1"))
    except Exception as e:
        # Deliberate best-effort: a broken/foreign file just means no patch.
        print("❌ Failed to edit H5:", e)
# ======================================================
# FALLBACK FEATURE EXTRACTOR
# ======================================================
def create_fallback_extractor():
    """Build a frozen MobileNetV2-based extractor emitting 512-dim feature vectors.

    Used when the trained hybrid model cannot be loaded from disk.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    backbone.trainable = False  # keep the ImageNet weights fixed

    image_in = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    h = tf.keras.applications.mobilenet_v2.preprocess_input(image_in)
    h = backbone(h, training=False)
    h = layers.Dense(512, activation='relu')(h)
    h = layers.Dropout(0.3)(h)
    h = layers.Dense(256, activation='relu')(h)
    feature_out = layers.Dense(512, activation='relu')(h)
    return Model(image_in, feature_out)
# ======================================================
# LOAD MODELS
# ======================================================
extractor, classifier = None, None  # populated by load_models() at startup
def load_models():
    """Load the feature extractor and classifier, falling back on any failure.

    Mutates the module globals ``extractor`` and ``classifier``. Every load
    error is downgraded to a usable fallback (MobileNetV2 extractor / dummy
    AdaBoost) so the API always comes up.
    """
    global extractor, classifier
    # --- feature extractor -------------------------------------------------
    try:
        fix_missing_predictions("hybrid_model.keras")
        extractor = load_model(
            "hybrid_model.keras",
            custom_objects=get_custom_objects(),
            compile=False,
        )
        print("βœ” Feature extractor loaded")
    except Exception as e:
        print(f"⚠ Failed to load extractor: {e}")
        extractor = create_fallback_extractor()
        print("βœ” Fallback extractor created")
    # --- classifier --------------------------------------------------------
    try:
        classifier = joblib.load("gbdt_model.pkl")
        print("βœ” Classifier loaded")
    except Exception as e:
        print(f"⚠ Failed to load classifier: {e}")
        from sklearn.ensemble import AdaBoostClassifier
        from sklearn.tree import DecisionTreeClassifier
        classifier = AdaBoostClassifier(
            estimator=DecisionTreeClassifier(max_depth=3),
            n_estimators=50,
            random_state=40,
        )
        # Fit on random data so predict()/predict_proba() are callable;
        # predictions are meaningless but keep the endpoint functional.
        feat_dim = extractor.output_shape[-1]
        dummy_features = np.random.randn(10, feat_dim)
        dummy_labels = np.random.randint(0, 2, 10)
        classifier.fit(dummy_features, dummy_labels)
        print("βœ” Dummy classifier created")
# ======================================================
# IMAGE PREPROCESSING
# ======================================================
def preprocess_image(img: Image.Image):
    """Convert a PIL image to a (1, IMG_SIZE, IMG_SIZE, 3) float32 batch in [0, 1].

    Forces 3-channel RGB up front so RGBA/palette inputs cannot reach the
    extractor with the wrong channel count (the original only handled 2-D
    grayscale arrays, after resizing).
    """
    if img.mode != "RGB":
        # Drops alpha, expands palettes, replicates grayscale to 3 channels.
        img = img.convert("RGB")
    arr = np.asarray(img)
    arr = cv2.resize(arr, (IMG_SIZE, IMG_SIZE))
    arr = arr.astype("float32") / 255.0
    # Defensive: if a 2-D array still slips through, replicate to 3 channels.
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    # NOTE(review): values are scaled to [0, 1] here, yet the fallback extractor
    # also applies mobilenet_v2.preprocess_input (documented to expect [0, 255]).
    # The double scaling is preserved for parity with the trained pipeline —
    # confirm against how the hybrid model was trained.
    return np.expand_dims(arr, axis=0)
# ======================================================
# PREDICTION
# ======================================================
def predict_image(img: Image.Image):
    """Run the full pipeline on one image.

    Returns ``{"label": "Real"|"Fake", "confidence": <percentage float>}``.
    Class convention: prediction 0 => Real, anything else => Fake.
    """
    batch = preprocess_image(img)
    features = extractor.predict(batch, verbose=0).flatten().reshape(1, -1)
    pred = int(classifier.predict(features)[0])
    # Original used a bare `except:` here, which also swallows
    # KeyboardInterrupt/SystemExit; probe capability explicitly instead.
    if hasattr(classifier, "predict_proba"):
        try:
            confidence = classifier.predict_proba(features)[0][pred] * 100
        except Exception:
            # e.g. estimator fitted without probability support
            confidence = 85.0
    else:
        confidence = 85.0  # fixed fallback when probabilities are unavailable
    label = "Real" if pred == 0 else "Fake"
    return {"label": label, "confidence": float(confidence)}
# ======================================================
# LIFESPAN + FASTAPI APP
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load both models once before serving requests.

    Note: the shutdown message only prints on a clean exit — there is no
    try/finally around the yield.
    """
    print("⚑ Starting app and loading models...")
    load_models()
    yield
    print("⚑ Shutting down app...")
# Application instance; `lifespan` loads the models before the first request.
app = FastAPI(title="Fake Image Detector API", lifespan=lifespan)
# CORS
# NOTE(review): wide-open CORS (all origins/methods/headers) — fine for a demo,
# tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)
# ROUTES
@app.get("/")
def root():
    """Health-check endpoint."""
    payload = {"message": "API is running!"}
    return payload
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Accept an uploaded image and return the Real/Fake prediction.

    Any failure (unreadable image, model error) yields a 400 with the error text,
    matching the original single try/except behavior.
    """
    try:
        image = Image.open(file.file).convert("RGB")
        result = predict_image(image)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)
    return JSONResponse(result)