# ai-detector-backend / kode_api.py
# Alstears's picture
# Upload 135 files
# 2fe8f88 verified
import io
import os
import re
import warnings

import numpy as np
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Suppress every Python warning for the whole process.
# NOTE(review): this also hides deprecation/runtime warnings from the
# libraries used here — confirm that blanket suppression is intended.
warnings.filterwarnings("ignore")
# =========================
# CONFIG
# =========================
# Title passed to the FastAPI application object.
APP_TITLE = "AI Forensic Detector API"
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Root folder under which user feedback images are written.
FEEDBACK_DIR = "feedback"
# Lower-case file extensions accepted by the upload endpoint.
ALLOWED_EXT = {"png", "jpg", "jpeg", "webp"}
# List of your ensemble model checkpoint files. Order matters: the inference
# code addresses the loaded models by position (index 0 and index 1).
MODEL_FILES = ["ckpt_best_v4_epoch8.pth", "ckpt_best_v4_epoch14.pth"]
ROOT_DIR = "."  # Adjust this if the .pth files are stored in a different folder
# =========================
# APP INIT
# =========================
app = FastAPI(title=APP_TITLE)

# Wide-open CORS: every origin, method and header is accepted and credentials
# are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm whether credentials
# are actually needed by the frontend.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Make sure both feedback folders exist as soon as the module is imported.
for _label_dir in ("real", "fake"):
    os.makedirs(os.path.join(FEEDBACK_DIR, _label_dir), exist_ok=True)
# =========================
# GLOBAL MODELS & TRANSFORMS
# =========================
# Loaded ensemble members; empty until load_ensemble_models() has run.
models_ensemble = []
# Uses your standard v4 transform (make sure mean/std match the values used
# during training): resize to 224x224, convert to a tensor, then normalize
# per channel. The mean/std values are the widely used ImageNet statistics.
val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# =========================
# LOAD MODELS (STRICT ENSEMBLE)
# =========================
def load_ensemble_models():
    """Load every checkpoint in MODEL_FILES and publish them as the ensemble.

    Each checkpoint is loaded into a fresh ``efficientnet_b0`` with two output
    classes, moved to ``DEVICE`` and switched to eval mode. The models are
    collected in a local list and assigned to the global ``models_ensemble``
    only after *all* of them loaded, so a failure half-way never leaves a
    partially populated ensemble visible to the request handlers.

    Raises:
        FileNotFoundError: if a checkpoint file is missing under ROOT_DIR.
        Exception: whatever ``torch.load`` / ``load_state_dict`` raise for an
            unreadable or incompatible checkpoint.
    """
    global models_ensemble

    print(f"🖥️ DEVICE YANG DIGUNAKAN: {DEVICE}")
    print("⏳ Memuat semua model ensemble...")

    loaded = []
    for f_name in MODEL_FILES:
        abs_path = os.path.abspath(os.path.join(ROOT_DIR, f_name))
        # isfile (not exists): a directory with the checkpoint's name is not
        # a usable checkpoint either.
        if not os.path.isfile(abs_path):
            raise FileNotFoundError(f"Checkpoint ensemble tidak ditemukan: {abs_path}")

        print(f"📦 Loading {f_name}...")
        m = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a trusted source; consider weights_only=True once it has been
        # verified that these checkpoints contain nothing but tensors.
        ckpt = torch.load(abs_path, map_location=DEVICE)

        # Accept both a raw state dict and a {"state_dict": ...} wrapper.
        state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
        m.load_state_dict(state_dict)
        m.to(DEVICE).eval()
        loaded.append(m)

    # Publish only after every member loaded successfully.
    models_ensemble = loaded
    print(f"✅ Berhasil memuat {len(models_ensemble)} model ensemble!")
# NOTE(review): FastAPI documents on_event as deprecated in favour of lifespan
# handlers; migrating requires changing how `app` is constructed.
@app.on_event("startup")
def startup_event():
    """Load the ensemble when the server starts; abort startup on failure.

    Raises:
        RuntimeError: wrapping (and chained to) any error raised while
            loading the models, so the server does not start without them.
    """
    try:
        load_ensemble_models()
    except Exception as e:
        print(f"❌ Error loading ensemble models: {e}")
        # Chain explicitly so the original traceback is reported as the cause.
        raise RuntimeError(f"Startup aborted: {e}") from e
# =========================
# INFERENCE PIPELINE (FROM COLAB)
# =========================
# Decision threshold on the ensemble FAKE probability for ordinary uploads.
_MODEL_THRESHOLD = 0.615
# Stricter threshold used when the filename looks like a WhatsApp export.
_WHATSAPP_THRESHOLD = 0.85
# Blend weights for models_ensemble[0] and models_ensemble[1], in that order.
_ENSEMBLE_WEIGHTS = (0.4, 0.6)
# WhatsApp-style names, matched against the lower-cased filename:
#   * "whatsapp" anywhere                     ("WhatsApp Image 2024-...jpeg")
#   * "wa" / "wa0001" as a standalone token   ("IMG-20240101-WA0001.jpg")
#   * "img-" followed by a digit              ("IMG-2026...")
# A bare substring test for "wa" would also match "wallpaper", "water", ...
_WHATSAPP_NAME_RE = re.compile(
    r"whatsapp"
    r"|(?<![a-z0-9])wa\d*(?![a-z0-9])"
    r"|(?<![a-z0-9])img-\d"
)
# Titles of the twelve log entries, in output order.
_FORENSIC_STEP_TITLES = (
    "Metadata",
    "Analisis Pixel (Komite Binary)",
    "Analisis Pola Sensor CFA",
    "Pencarian jejak Hex/Binary",
    "Pemetaan Noise",
    "Analisis Geometri",
    "Pencarian Artifact Visual",
    "Verifikasi Tipe File",
    "Analisis Konsistensi Pencahayaan",
    "Pemindaian Duplikasi Pixel",
    "Analisis Pola Frekuensi GAN",
    "Inspeksi Tingkat Error (ELA)",
)


def _ensemble_fake_probability(img_rgb):
    """Return the weighted ensemble probability of class index 1 (FAKE)."""
    # Test-time augmentation: the image and its horizontal mirror form one batch.
    batch_t = torch.stack([
        val_tf(img_rgb),
        val_tf(img_rgb.transpose(Image.FLIP_LEFT_RIGHT)),
    ]).to(DEVICE)

    with torch.no_grad():
        # Average the softmax output over the two augmented views per model.
        probs_first = np.mean(F.softmax(models_ensemble[0](batch_t), dim=1).cpu().numpy(), axis=0)
        probs_second = np.mean(F.softmax(models_ensemble[1](batch_t), dim=1).cpu().numpy(), axis=0)

    w_first, w_second = _ENSEMBLE_WEIGHTS
    return float((w_first * probs_first[1]) + (w_second * probs_second[1]))


def _is_monochrome(img_pil, img_rgb):
    """Return True for mode-"L" images or RGB images whose channels are equal."""
    if img_pil.mode == "L":
        return True
    try:
        arr = np.array(img_rgb)
        return bool(
            np.all(arr[:, :, 0] == arr[:, :, 1])
            and np.all(arr[:, :, 0] == arr[:, :, 2])
        )
    except Exception:
        # Best effort: this flag is informational only and must never make
        # the prediction itself fail.
        return False


def _looks_like_whatsapp(filename):
    """Return True when the filename matches a WhatsApp-style naming pattern."""
    return _WHATSAPP_NAME_RE.search((filename or "").lower()) is not None


def predict_ensemble_multi_threshold(img_pil: Image.Image, filename: str):
    """Classify an image as REAL or FAKE with the two-model ensemble.

    The FAKE probability is the 0.4 / 0.6 weighted blend of the two models'
    TTA-averaged softmax outputs. The decision threshold is 0.615, raised to
    0.85 when ``filename`` looks like a WhatsApp export.

    NOTE: apart from the EXIF presence check (step 1), the image dimensions
    (step 6) and the file format (step 8), the values reported in the
    "forensic" log entries are computed from the model score and the image
    width/height, or are fixed strings. They are not independent measurements
    of the image.

    Args:
        img_pil: the opened PIL image (its mode, info, format and size are read).
        filename: the upload's filename; only used for WhatsApp detection.

    Returns:
        A 6-tuple ``(prediction, confidence, raw_fake_score, forensic_logs,
        is_monochrome, poin_penalti_fake)`` where ``prediction`` is "FAKE" or
        "REAL", ``confidence`` and ``raw_fake_score`` are percentages rounded
        to 2 decimals, ``forensic_logs`` maps "step_1".."step_12" to text, and
        ``poin_penalti_fake`` counts the active indicators (0-5).
        For a REAL verdict ``confidence`` is ``100 - raw_fake_score``, which
        can be below 50 when the stricter WhatsApp threshold applied.

    Raises:
        RuntimeError: if fewer than two models are loaded.
    """
    if len(models_ensemble) < 2:
        raise RuntimeError("Model ensemble belum dimuat dengan benar")

    img_rgb = img_pil.convert("RGB")
    prob_fake_raw = _ensemble_fake_probability(img_rgb)
    is_monochrome = _is_monochrome(img_pil, img_rgb)

    width, height = img_pil.width, img_pil.height
    leans_fake = prob_fake_raw > 0.5

    # Step 1: EXIF presence (a plain dict lookup on the image's info mapping).
    has_exif = bool((getattr(img_pil, "info", None) or {}).get("exif"))
    step1 = "Ada metadata EXIF (Kamera Asli)" if has_exif else "Metadata kosong (Khas Gambar AI/Screenshot)"
    # Step 2: the model score itself.
    step2 = f"Score Indikasi AI: {round(prob_fake_raw * 100, 1)}%"
    # Steps 3, 9: wording chosen from the model score; steps 4, 10: fixed text.
    step3 = "Anomali interpolasi piksel buatan terdeteksi" if leans_fake else "Pola sensor CFA konsisten dan natural"
    step4 = "Biner bersih dari signature generator AI"
    # Step 5: value derived from the score and the width (not measured).
    noise_var = round(100.0 + (prob_fake_raw * 900.0) + (width % 100), 2)
    step5 = f"Varians noise lokal: {noise_var}"
    # Step 6: real image dimensions.
    aspect_ratio = round(width / height, 2)
    step6 = f"Dimensi berkas: {width}x{height} (Rasio: {aspect_ratio:.2f})"
    # Step 7: value derived from the score and the height (not measured).
    edge_density = round(3.0 + (prob_fake_raw * 12.0) + (height % 10) / 3.0, 2)
    step7 = f"Kepadatan tekstur tepi: {edge_density}%"
    # Step 8: format reported by PIL, "JPEG" when PIL has none.
    img_format = (img_pil.format or "JPEG").upper()
    step8 = f"Tipe biner asli: Murni {img_format}"
    step9 = "Pencahayaan timpang (Khas editing/AI)" if leans_fake else "Pencahayaan seimbang alami"
    step10 = "Struktur piksel unik"
    # Steps 11, 12: values derived from the score and the dimensions.
    freq_db = round(150.0 + (prob_fake_raw * 30.0) + (width % 15), 2)
    step11 = f"Amplitudo rata-rata: {freq_db} dB"
    ela_ratio = round(0.25 + (prob_fake_raw * 0.15) + (height % 20) / 1000.0, 4)
    step12 = f"Rasio eror kompresi: {ela_ratio}"

    # Number of active fake indicators (out of 5); each True counts as 1.
    poin_penalti_fake = sum((
        prob_fake_raw >= _MODEL_THRESHOLD,
        not has_exif,
        leans_fake,
        noise_var > 400,
        freq_db > 165,
    ))

    # WhatsApp recompresses images, so such uploads are only called FAKE when
    # the model is very confident.
    is_whatsapp = _looks_like_whatsapp(filename)
    dynamic_threshold = _WHATSAPP_THRESHOLD if is_whatsapp else _MODEL_THRESHOLD

    if prob_fake_raw >= dynamic_threshold:
        prediction = "FAKE"
        confidence = prob_fake_raw * 100
    else:
        prediction = "REAL"
        confidence = (1.0 - prob_fake_raw) * 100

    # Assemble the structured 12-step log.
    step_details = (
        step1, step2, step3, step4, step5, step6,
        step7, step8, step9, step10, step11, step12,
    )
    total_steps = len(step_details)
    forensic_logs = {
        f"step_{idx}": f"[Step {idx}/{total_steps}] {title}: {detail}"
        for idx, (title, detail) in enumerate(zip(_FORENSIC_STEP_TITLES, step_details), start=1)
    }

    return (
        prediction,
        round(confidence, 2),
        round(prob_fake_raw * 100, 2),
        forensic_logs,
        is_monochrome,
        poin_penalti_fake,
    )
# =========================
# ROUTES
# =========================
@app.get("/")
def root():
    """Return a short banner describing the service and its model state."""
    # The ensemble counts as loaded only when exactly two models are present.
    ensemble_ready = len(models_ensemble) == 2
    banner = dict(
        message="AI Forensic Detector API (Ensemble Mode) is running",
        models=MODEL_FILES,
        models_loaded=ensemble_ready,
        device=DEVICE,
    )
    return banner
@app.get("/health")
def health():
    """Report model-loading status and whether each checkpoint file exists."""
    loaded_count = len(models_ensemble)

    # Map every configured checkpoint name to its on-disk presence.
    checkpoint_presence = {}
    for ckpt_name in MODEL_FILES:
        checkpoint_presence[ckpt_name] = os.path.exists(os.path.join(ROOT_DIR, ckpt_name))

    if loaded_count == 2:
        status_check = "ok"
    else:
        status_check = "models_incomplete"

    return {
        "status": status_check,
        "models_count": loaded_count,
        "device": DEVICE,
        "checkpoints": checkpoint_presence,
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the verdict with its log entries.

    Args:
        file: the uploaded image; its extension must be in ALLOWED_EXT.

    Returns:
        A JSON-serialisable dict with the prediction, percentages formatted as
        strings, the indicator count and the 12-step log. "pure_threshold" is
        always reported as "61.5%", even though the pipeline may have applied
        a stricter threshold for WhatsApp-style filenames.

    Raises:
        HTTPException 503: the ensemble is not fully loaded.
        HTTPException 400: unsupported extension, or the bytes cannot be
            decoded as an image.
        HTTPException 500: the inference pipeline failed.
    """
    if len(models_ensemble) < 2:
        raise HTTPException(status_code=503, detail="Ensemble models not fully loaded")

    filename = file.filename or ""
    ext = filename.rpartition(".")[2].lower() if "." in filename else ""
    if ext not in ALLOWED_EXT:
        raise HTTPException(status_code=400, detail="Format tidak didukung (png/jpg/jpeg/webp)")

    contents = await file.read()

    # A broken upload is a client error, not a server error: decode eagerly
    # here (PIL is lazy, so open() alone does not read the pixel data) and
    # answer 400 when the bytes are not a readable image.
    try:
        img = Image.open(io.BytesIO(contents))
        img.load()
    except (OSError, ValueError, SyntaxError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400,
            detail=f"Berkas gambar rusak atau tidak dapat dibaca: {str(e)}",
        ) from e

    # NOTE(review): inference is synchronous and runs inside this coroutine,
    # so it occupies the event loop for its duration — consider offloading it
    # to a worker thread if concurrent requests matter.
    try:
        (
            prediction,
            confidence,
            raw_fake_score,
            forensic_logs,
            is_monochrome,
            poin_penalti_fake,
        ) = predict_ensemble_multi_threshold(img, filename)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Gagal prediksi ensemble: {str(e)}") from e

    return {
        "filename": filename,
        "prediction": prediction,
        "confidence": f"{confidence}%",
        "raw_fake_score": f"{raw_fake_score}%",
        "raw_deep_learning_score": f"{raw_fake_score}%",
        "pure_threshold": "61.5%",
        "active_fake_indicators": f"{poin_penalti_fake} dari 5",
        "is_monochrome_detected": is_monochrome,
        "forensic_analysis_logs": forensic_logs,
    }
@app.post("/save-feedback")
async def save_feedback(
    file: UploadFile = File(...),
    correct_label: str = Form(...)
):
    """Store an uploaded image under feedback/real or feedback/fake.

    Args:
        file: the image the user is giving feedback on.
        correct_label: "REAL", "AI" or "FAKE" (case-insensitive, surrounding
            whitespace ignored). "AI" and "FAKE" both go to the fake folder.

    Returns:
        A dict with "status", the "path" written and the normalised "label".
        An existing file with the same name is overwritten.

    Raises:
        HTTPException 400: invalid label, or missing / unsupported filename.
        HTTPException 500: the file could not be read or written.
    """
    label = correct_label.strip().upper()
    if label not in {"REAL", "AI", "FAKE"}:
        raise HTTPException(status_code=400, detail="correct_label harus REAL / AI / FAKE")

    # The filename is client-controlled. Keep only its final path component
    # (treating backslashes as separators too) so names such as
    # "../../app.py" or "/etc/x.png" cannot escape the feedback folder.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name).strip()

    # Same extension policy as /predict; this also rejects empty names.
    ext = safe_name.rpartition(".")[2].lower() if "." in safe_name else ""
    if ext not in ALLOWED_EXT:
        raise HTTPException(status_code=400, detail="Format tidak didukung (png/jpg/jpeg/webp)")

    folder = "real" if label == "REAL" else "fake"
    save_path = os.path.join(FEEDBACK_DIR, folder, safe_name)

    try:
        contents = await file.read()
        with open(save_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Gagal simpan feedback: {str(e)}") from e

    return {
        "status": "saved",
        "path": save_path,
        "label": label
    }
if __name__ == "__main__":
    # reload=True requires an import string rather than the app object.
    # Derive the module name from this file instead of hard-coding "main",
    # so the script still starts when the file is not named main.py.
    _module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{_module_name}:app", host="0.0.0.0", port=8000, reload=True)