# Author: Syahhh01 — commit 6c958a0 ("Update app.py", verified)
import os
import tempfile
from fastapi import FastAPI, UploadFile, File, HTTPException
import torch
import torch.nn as nn
import torchaudio
import torchcodec # Added as requested to handle audio parsing
import soundfile # noqa: F401 — diperlukan sebagai backend torchaudio
import numpy as np
# ASGI application instance; the @app.get / @app.post route decorators in this
# module register their endpoints on it.
app = FastAPI()
# =============== CONFIGURATION ==================
# Sample rate (Hz) every input is resampled to before feature extraction.
SR = 16_000
# Number of MFCC coefficients kept per frame.
N_MFCC = 40
# Number of mel filterbank bands the MFCCs are derived from.
N_MELS = 64
# Class names indexed by the model's output position. This order must match the
# label encoding of the training dataset (0: real, 1: fake).
LABELS = ["real", "fake"]
# =============== MODEL ARCHITECTURE ==================
class HybridAudioCNN(nn.Module):
    """Two-branch CNN that fuses raw-waveform features with MFCC features.

    The waveform branch is three strided 1D conv stages pooled to a fixed
    length of 32. The MFCC branch is three 2D conv stages (with one 2x2
    max-pool) pooled to a fixed 8x8 grid. Both outputs are flattened,
    concatenated and fed to a two-layer MLP head.

    Submodule names and ``nn.Sequential`` indices define the checkpoint keys
    consumed by ``load_state_dict``, so they must stay exactly as built here.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Waveform branch: Conv1d(k=5, stride=2) -> BN -> ReLU, three times.
        wave_layers = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            wave_layers += [
                nn.Conv1d(c_in, c_out, 5, stride=2, padding=2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
        wave_layers.append(nn.AdaptiveAvgPool1d(32))  # -> [B, 128, 32]
        self.waveform_branch = nn.Sequential(*wave_layers)

        # MFCC branch: 3x3 "same" convs with a single 2x2 max-pool in between.
        def conv2d_stage(c_in, c_out):
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        self.mfcc_branch = nn.Sequential(
            *conv2d_stage(1, 32),
            *conv2d_stage(32, 64),
            nn.MaxPool2d(2),
            *conv2d_stage(64, 128),
            nn.AdaptiveAvgPool2d((8, 8)),  # -> [B, 128, 8, 8]
        )

        # Flattened waveform features (128*32) + flattened MFCC features (128*8*8).
        fused_features = 128 * 32 + 128 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Linear(fused_features, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    def forward(self, waveform, mfcc):
        """Return class logits of shape ``[B, num_classes]``.

        Args:
            waveform: ``[B, T]`` audio samples; a channel axis is inserted for Conv1d.
            mfcc: ``[B, n_mfcc, frames]`` features; a channel axis is inserted for Conv2d.
        """
        wave_feat = self.waveform_branch(waveform.unsqueeze(1)).flatten(1)
        mfcc_feat = self.mfcc_branch(mfcc.unsqueeze(1)).flatten(1)
        return self.classifier(torch.cat((wave_feat, mfcc_feat), dim=1))
# =============== MODEL INITIALISATION ==================
model = HybridAudioCNN(num_classes=2)
# The checkpoint path is relative to the process's working directory: keep
# best_hybrid_cnn.pth (or last_model.pth, renamed) in the same folder as this
# script and launch the server from there.
try:
    # weights_only=True restricts unpickling to tensors and plain containers, so
    # a tampered checkpoint cannot execute arbitrary code at load time.
    model.load_state_dict(
        torch.load(
            "best_hybrid_cnn.pth",
            map_location=torch.device('cpu'),
            weights_only=True,
        )
    )
except Exception as e:
    # Deliberate best-effort: keep the API process alive even without weights,
    # but warn loudly because predictions from an untrained network are meaningless.
    print(f"Peringatan: Gagal memuat model. Pastikan file best_hybrid_cnn.pth tersedia. Error: {e}")
# Always switch to inference mode, even if loading failed: BatchNorm must use its
# running statistics and Dropout must be disabled for deterministic predictions.
model.eval()
# =============== AUDIO PREPROCESSING ==================
import torch.nn.functional as F
# =============== AUDIO PREPROCESSING (CHUNKED) ==================
def load_and_preprocess_audio_chunks(file_path, target_seconds=2.0):
    """Load an audio file and split it into fixed-length, model-ready chunks.

    The audio is down-mixed to mono, resampled to ``SR`` Hz and cut into
    consecutive non-overlapping windows of ``target_seconds``. The final window
    is right-padded with zeros to full length. For every window an MFCC matrix
    is computed and standardised over the whole matrix
    (``(x - mean) / (std + 1e-6)``).

    Args:
        file_path: Path of an audio file readable by ``torchaudio.load``.
        target_seconds: Window length in seconds (default 2.0, i.e. 32000
            samples at 16 kHz).

    Returns:
        ``(wave_batch, mfcc_batch)``, where ``wave_batch`` is
        ``[num_chunks, chunk_samples]`` and ``mfcc_batch`` is
        ``[num_chunks, N_MFCC, frames]``.

    Raises:
        ValueError: If ``target_seconds`` gives a window shorter than one
            sample, or if the decoded audio contains no samples.
        Exceptions from ``torchaudio.load`` (unreadable/unsupported files)
        propagate unchanged.
    """
    chunk_samples = int(target_seconds * SR)
    if chunk_samples <= 0:
        # A zero/negative window would make range() below fail or loop zero times.
        raise ValueError(f"target_seconds terlalu kecil: {target_seconds!r}")

    # 1. Load, down-mix to mono and resample to SR.
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != SR:
        resample_transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SR)
        waveform = resample_transform(waveform)
    waveform = waveform.squeeze(0)  # -> 1D [T]

    total_samples = waveform.shape[0]
    if total_samples == 0:
        # Without this guard torch.stack([]) below fails with an opaque error.
        raise ValueError("File audio tidak berisi sampel (durasi 0).")

    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=SR, n_mfcc=N_MFCC,
        melkwargs={"n_fft": 512, "n_mels": N_MELS, "hop_length": 160, "f_min": 80, "f_max": 7600}
    )

    # 2. Slice into consecutive windows; zero-pad the last one to full length.
    wave_chunks = []
    mfcc_chunks = []
    for start in range(0, total_samples, chunk_samples):
        chunk = waveform[start : start + chunk_samples]
        if chunk.shape[0] < chunk_samples:
            chunk = F.pad(chunk, (0, chunk_samples - chunk.shape[0]))

        # Per-chunk MFCC, standardised over the whole matrix.
        mfcc = mfcc_transform(chunk)
        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-6)

        wave_chunks.append(chunk)
        mfcc_chunks.append(mfcc)

    # 3. Stack into batch tensors.
    return torch.stack(wave_chunks), torch.stack(mfcc_chunks)
# =============== API ENDPOINTS ==================
@app.get("/")
async def root():
    """Liveness endpoint: report that the service is up."""
    status_message = "Audio Deepfake Detector (Hybrid CNN) API is running!"
    payload = {"message": status_message}
    return payload
# =============== API ENDPOINTS (BATCHED INFERENCE) ==================
@app.post("/predict-audio")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file as ``real`` or ``fake``.

    The upload is saved to a temporary file and split into 2-second chunks by
    ``load_and_preprocess_audio_chunks``. Every chunk is scored by ``model``,
    and the per-chunk softmax probabilities are averaged into one verdict.

    Returns:
        A dict with the filename, the number of chunks scored, the averaged
        probability vector (ordered as ``LABELS``), the winning label, its
        probability, and the per-class probabilities.

    Raises:
        HTTPException(500): The upload could not be saved to a temporary file.
        HTTPException(400): The audio could not be decoded/preprocessed.
    """
    # Local imports: contextlib is stdlib and fastapi is already a dependency.
    import contextlib
    from fastapi.concurrency import run_in_threadpool

    # Keep the client's extension; the decoder may use it as a format hint.
    ext = os.path.splitext(file.filename)[1] if file.filename else ""
    if not ext:
        ext = ".wav"

    tmp_path = None
    try:
        # 1. Persist the upload so torchaudio can read it from a path.
        try:
            contents = await file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
                # Record the path before writing so a failed write is still cleaned up.
                tmp_path = tmp_file.name
                tmp_file.write(contents)
        except Exception as e:
            raise HTTPException(status_code=500, detail="Gagal menyimpan file audio sementara.") from e

        # 2. Decode + chunk. Decoding, resampling and MFCC extraction are
        #    CPU-bound, so run them in a worker thread instead of the event loop.
        try:
            waveform_batch, mfcc_batch = await run_in_threadpool(
                load_and_preprocess_audio_chunks, tmp_path, target_seconds=2.0
            )
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Gagal memproses file audio. Detail: {str(e)}") from e
    finally:
        # Remove the temp file on every path. Cleanup is best-effort: a failed
        # delete must not replace the real response or error.
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)

    def _score_chunks():
        # The batch already has a leading chunk axis, so no unsqueeze is needed.
        with torch.no_grad():
            logits = model(waveform_batch, mfcc_batch)
            return torch.softmax(logits, dim=1)  # [num_chunks, 2]

    # 3 & 4. Batched inference, also kept off the event loop.
    probs = await run_in_threadpool(_score_chunks)

    # Average the class probabilities over all chunks of the recording.
    avg_probs = torch.mean(probs, dim=0).numpy().tolist()

    # 5. Pick the most probable class (ties resolve to the lower index).
    pred_idx = int(np.argmax(avg_probs))
    pred_label = LABELS[pred_idx]
    confidence = avg_probs[pred_idx]
    return {
        "filename": file.filename,
        "total_chunks_processed": waveform_batch.shape[0],  # how many chunks were scored
        "prediction": avg_probs,
        "label": pred_label,
        "confidence": confidence,
        "details": {
            "real_probability": avg_probs[0],
            "fake_probability": avg_probs[1]
        }
    }