# Origin: commit c9fd5f9 ("Update app.py") by Syahhh01 — copied page header, kept as a comment.
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Annotated
from typing import Any

import tensorflow as tf
from fastapi import (
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile
)
from pydantic import BaseModel, Field

from custom_layers import (
    AdaptiveAvgPool1D,
    AdaptiveAvgPool2D
)
from genai_service import (
    generate_detection_analysis
)
from inference import predict_audio
# ============================================================
# CONFIGURATION
# ============================================================
# Keras model file; a relative path, so it resolves against the
# current working directory of the server process.
MODEL_PATH = Path("best_torchlike_mfcc_waveform_model.keras")

# Upload extensions accepted by /predict (the upload's suffix is
# lower-cased before it is compared against this set).
ALLOWED_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}

# Upload size limit, in megabytes and in bytes (1 MB = 1024 ** 2 bytes).
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 ** 2
# Loaded Keras model shared by the route handlers; None until the lifespan
# startup below has run, and reset to None on shutdown.
model: tf.keras.Model | None = None
# ============================================================
# LOAD MODEL ON STARTUP
# ============================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Keras model at application startup and drop it on shutdown.

    The model is stored in the module-level ``model`` global, which the
    route handlers read.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist; raising here
            aborts application startup.
    """
    global model
    if not MODEL_PATH.exists():
        raise FileNotFoundError(
            f"Model tidak ditemukan: {MODEL_PATH}"
        )
    print("Loading model...")
    model = tf.keras.models.load_model(
        MODEL_PATH,
        # The saved model uses these project-defined layers, which Keras
        # cannot deserialize without being told about them.
        custom_objects={
            "AdaptiveAvgPool1D": AdaptiveAvgPool1D,
            "AdaptiveAvgPool2D": AdaptiveAvgPool2D
        },
        # Skip restoring the training configuration; the app only runs
        # inference with the loaded model.
        compile=False
    )
    print("Model loaded successfully.")
    try:
        yield
    finally:
        # `finally` so the reference is released even when shutdown is
        # triggered by an exception thrown into this context manager.
        model = None
# ============================================================
# FASTAPI APP
# ============================================================
# Application instance. `lifespan` is registered so the model is loaded
# during startup, before requests are handled.
app = FastAPI(
    title="Deepfake Audio Detection API",
    version="1.0.0",
    description="REST API untuk mendeteksi audio real atau fake menggunakan model MFCC + Waveform.",
    lifespan=lifespan,
)
# ============================================================
# GENERATIVE AI REQUEST SCHEMA
# ============================================================
class DetectionAnalysisRequest(BaseModel):
    """Request body for ``/generate-analysis``: a summarized detection result.

    All fields are required. The per-field bounds below are enforced by
    Pydantic; the cross-field check that real_clips + fake_clips equals
    total_clips is performed by the endpoint, not by this schema.
    """

    # Predicted label; must be exactly "real" or "fake".
    prediction: str = Field(..., pattern="^(real|fake)$")
    # Decision threshold that was applied, within [0, 1].
    threshold: float = Field(..., ge=0.0, le=1.0)
    # Clip counts: at least one clip in total, none of the counts negative.
    total_clips: int = Field(..., ge=1)
    real_clips: int = Field(..., ge=0)
    fake_clips: int = Field(..., ge=0)
    # Average probabilities for each class, each within [0, 1].
    average_probability_real: float = Field(..., ge=0.0, le=1.0)
    average_probability_fake: float = Field(..., ge=0.0, le=1.0)
# ============================================================
# ROUTES
# ============================================================
@app.get("/")
def root():
    """Return a static descriptor of the service and its main endpoints."""
    service_info = dict(
        message="Deepfake Audio Detection API",
        status="running",
        docs="/docs",
        predict_endpoint="/predict",
        default_threshold=0.60,
    )
    return service_info
@app.get("/health")
def health():
    """Report whether the model global is currently loaded."""
    is_loaded = model is not None
    return {
        "status": "healthy" if is_loaded else "model_not_loaded",
        "model_loaded": is_loaded,
    }
@app.post("/predict")
async def predict(
    file: Annotated[
        UploadFile,
        File(
            description=(
                "File audio dengan format WAV, MP3, "
                "FLAC, OGG, atau M4A."
            )
        )
    ],
    threshold: Annotated[
        float,
        Form(
            ge=0.0,
            le=1.0,
            description=(
                "Audio dianggap fake jika probability_fake "
                "lebih besar atau sama dengan threshold."
            )
        )
    ] = 0.60
):
    """Predict whether an uploaded audio file is real or fake.

    The default threshold is 0.60; it can be overridden on every request.

    Args:
        file: Uploaded audio; its (lower-cased) extension must be in
            ``ALLOWED_EXTENSIONS``.
        threshold: Decision threshold in [0, 1], forwarded to
            ``predict_audio``.

    Returns:
        A dict containing the original ``filename`` merged with the
        mapping returned by ``predict_audio``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for an
            unsupported extension, an empty file, or a ``ValueError``
            raised during inference; 413 if the upload exceeds
            ``MAX_FILE_SIZE_BYTES``; 500 for any other inference failure.
    """
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model belum siap digunakan."
        )

    original_filename = file.filename or "uploaded_audio.wav"
    suffix = Path(original_filename).suffix.lower()
    if suffix not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=(
                "Format audio tidak didukung. "
                "Gunakan WAV, MP3, FLAC, OGG, atau M4A."
            )
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory.
    file_content = await file.read(MAX_FILE_SIZE_BYTES + 1)
    if not file_content:
        raise HTTPException(
            status_code=400,
            detail="File audio kosong."
        )
    if len(file_content) > MAX_FILE_SIZE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=(
                f"Ukuran file terlalu besar. "
                f"Maksimal {MAX_FILE_SIZE_MB} MB."
            )
        )

    temp_path: Path | None = None
    try:
        # delete=False so the file survives closing and can be reopened by
        # path during inference; it is removed in `finally` instead.
        with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so the file is still cleaned
            # up if the write itself fails (e.g. disk full).
            temp_path = Path(temp_file.name)
            temp_file.write(file_content)

        # predict_audio is a synchronous call that runs the model; executing
        # it directly in this async handler would block the event loop
        # (and every other request) until inference finishes.
        result = await asyncio.to_thread(
            predict_audio,
            model=model,
            file_path=temp_path,
            threshold=threshold
        )
        return {
            "filename": original_filename,
            **result
        }
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail=str(error)
        ) from error
    except Exception as error:
        raise HTTPException(
            status_code=500,
            detail=f"Inference gagal: {str(error)}"
        ) from error
    finally:
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
# ============================================================
# GENERATIVE AI ANALYSIS ENDPOINT
# ============================================================
@app.post("/generate-analysis")
def generate_analysis(
    request: DetectionAnalysisRequest
):
    """Produce an explanation of a prediction result using the Gemini API.

    This endpoint is a secondary feature: the prediction label still comes
    from the TensorFlow model and is echoed back unchanged; only the
    explanation text is generated here.

    Raises:
        HTTPException: 400 if real_clips + fake_clips differs from
            total_clips; 503 if the analysis service raises
            ``RuntimeError``; 500 for any other failure.
    """
    clip_sum = request.real_clips + request.fake_clips
    if clip_sum != request.total_clips:
        raise HTTPException(
            status_code=400,
            detail="Jumlah real_clips dan fake_clips harus sama dengan total_clips.",
        )

    try:
        analysis = generate_detection_analysis(
            detection_result=request.model_dump()
        )
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Gagal membuat analisis AI: {exc}",
        ) from exc

    return {"prediction": request.prediction, "analysis": analysis}