Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| import torchcodec # Added as requested to handle audio parsing | |
| import soundfile # noqa: F401 — diperlukan sebagai backend torchaudio | |
| import numpy as np | |
app = FastAPI()

# =============== CONFIGURATION ==================
SR: int = 16000  # sample rate (Hz) every input is resampled to before inference
N_MFCC: int = 40  # number of MFCC coefficients per frame
N_MELS: int = 64  # mel filterbank size used inside the MFCC transform
# Output index -> label, in the same order as the dataset labels
# (0: real, 1: fake).
LABELS = ["real", "fake"]
# =============== MODEL ARCHITECTURE ==================
class HybridAudioCNN(nn.Module):
    """Two-branch CNN that fuses raw-waveform and MFCC features.

    * ``waveform_branch``: three strided Conv1d -> BatchNorm -> ReLU stages
      over the raw signal, adaptively pooled to ``[B, 128, 32]``.
    * ``mfcc_branch``: three Conv2d -> BatchNorm -> ReLU stages (with a 2x2
      max-pool after the second) over the MFCC matrix, adaptively pooled to
      ``[B, 128, 8, 8]``.
    * ``classifier``: MLP over the concatenated, flattened features.

    The attribute names and ``nn.Sequential`` indices determine the
    ``state_dict`` keys, so they must stay stable for checkpoints to load.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Waveform branch: kernel 5, stride 2 conv stages, then pool to length 32.
        wave_layers = []
        for in_ch, out_ch in ((1, 32), (32, 64), (64, 128)):
            wave_layers.extend([
                nn.Conv1d(in_ch, out_ch, 5, stride=2, padding=2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ])
        wave_layers.append(nn.AdaptiveAvgPool1d(32))  # -> [B, 128, 32]
        self.waveform_branch = nn.Sequential(*wave_layers)

        # MFCC branch: "same"-padded 3x3 conv stages.
        def conv2d_stage(in_ch, out_ch):
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        self.mfcc_branch = nn.Sequential(
            *conv2d_stage(1, 32),
            *conv2d_stage(32, 64),
            nn.MaxPool2d(2),
            *conv2d_stage(64, 128),
            nn.AdaptiveAvgPool2d((8, 8)),  # -> [B, 128, 8, 8]
        )

        # Fused width: 128*32 from the waveform branch + 128*8*8 from the MFCC branch.
        fused_features = 128 * 32 + 128 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Linear(fused_features, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    def forward(self, waveform, mfcc):
        """Return class logits of shape ``[B, num_classes]``.

        Args:
            waveform: Batch of 1-D signals, ``[B, T]`` (a channel axis is added here).
            mfcc: Batch of 2-D feature maps, ``[B, H, W]`` (e.g. ``[B, N_MFCC, frames]``).
        """
        wave_features = self.waveform_branch(waveform.unsqueeze(1)).flatten(1)
        mfcc_features = self.mfcc_branch(mfcc.unsqueeze(1)).flatten(1)
        fused = torch.cat((wave_features, mfcc_features), dim=1)
        return self.classifier(fused)
# =============== MODEL INITIALISATION ==================
model = HybridAudioCNN(num_classes=2)
# best_hybrid_cnn.pth (or last_model.pth) must be available. The path is
# relative, so it is resolved against the current working directory.
try:
    # The checkpoint is a plain state_dict, so weights_only=True is sufficient
    # and avoids unpickling arbitrary Python objects from disk.
    model.load_state_dict(
        torch.load(
            "best_hybrid_cnn.pth",
            map_location=torch.device('cpu'),
            weights_only=True,
        )
    )
except Exception as e:
    print(f"Peringatan: Gagal memuat model. Pastikan file best_hybrid_cnn.pth tersedia. Error: {e}")
# Switch to inference mode unconditionally (outside the try). Otherwise, a
# failed load would leave BatchNorm in training mode, updating its running
# statistics on every request, and keep Dropout active.
model.eval()
# =============== AUDIO PREPROCESSING (CHUNKED) ==================
import torch.nn.functional as F


def _load_mono_resampled(file_path):
    """Decode *file_path*, down-mix to mono and resample to ``SR`` Hz.

    Returns a 1-D waveform tensor of shape ``[T]``.
    """
    waveform, sample_rate = torchaudio.load(file_path)  # [channels, T]
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != SR:
        resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SR)
        waveform = resample(waveform)
    return waveform.squeeze(0)


def load_and_preprocess_audio_chunks(file_path, target_seconds=2.0):
    """Load an audio file and split it into fixed-length chunks for the model.

    The audio is down-mixed to mono and resampled to ``SR`` Hz. It is then cut
    into consecutive ``target_seconds`` windows, with the last window
    zero-padded to full length. An MFCC matrix is computed for every window
    and standardised (zero mean, unit variance) over all of its own values.

    Args:
        file_path: Path of the audio file to decode.
        target_seconds: Chunk length in seconds (2.0 -> 32000 samples at 16 kHz).

    Returns:
        ``(wave_batch, mfcc_batch)`` with shapes ``[num_chunks, chunk_samples]``
        and ``[num_chunks, N_MFCC, frames]``.

    Raises:
        ValueError: If ``target_seconds`` yields chunks shorter than one
            sample, or the decoded audio contains no samples.
    """
    chunk_samples = int(target_seconds * SR)
    if chunk_samples <= 0:
        raise ValueError(
            f"target_seconds={target_seconds!r} gives chunks shorter than one sample."
        )

    waveform = _load_mono_resampled(file_path)
    total_samples = waveform.shape[0]
    if total_samples == 0:
        # Without this guard, batching an empty chunk list fails with an opaque error.
        raise ValueError("The audio file contains no samples.")

    # Zero-pad the tail so the signal splits into whole chunks, then view it
    # as [num_chunks, chunk_samples]. This matches slicing every chunk_samples
    # samples and padding only the final, shorter slice.
    num_chunks = -(-total_samples // chunk_samples)  # ceil division
    pad_length = num_chunks * chunk_samples - total_samples
    wave_batch = F.pad(waveform, (0, pad_length)).reshape(num_chunks, chunk_samples)

    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=SR, n_mfcc=N_MFCC,
        melkwargs={"n_fft": 512, "n_mels": N_MELS, "hop_length": 160, "f_min": 80, "f_max": 7600},
    )
    # MFCC operates on the last (time) axis, so all chunks are processed at
    # once: [num_chunks, N_MFCC, frames].
    mfcc_batch = mfcc_transform(wave_batch)
    # Standardise each chunk's MFCC matrix independently.
    mean = mfcc_batch.mean(dim=(1, 2), keepdim=True)
    std = mfcc_batch.std(dim=(1, 2), keepdim=True)
    mfcc_batch = (mfcc_batch - mean) / (std + 1e-6)
    return wave_batch, mfcc_batch
# =============== API ENDPOINTS ==================
# Without a route decorator FastAPI never registers this function, so the
# endpoint would be unreachable.
@app.get("/")
async def root():
    """Return a static message confirming the API process is running."""
    return {"message": "Audio Deepfake Detector (Hybrid CNN) API is running!"}
# =============== API ENDPOINTS (BATCHED CHUNK INFERENCE) ==================
# Without a route decorator FastAPI never registers this function.
# NOTE(review): the "/predict" path is assumed; confirm it matches the clients.
@app.post("/predict")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file as real or fake.

    The upload is written to a temporary file that keeps the original
    extension, so the decoder can recognise the format. The audio is split
    into 2-second chunks and every chunk is scored by ``model``. The per-chunk
    softmax probabilities are averaged into a single verdict.

    Returns:
        A dict with the filename, the number of chunks scored, the averaged
        probabilities (``prediction``, in ``LABELS`` order), the winning
        label, its probability (``confidence``) and a per-class breakdown.

    Raises:
        HTTPException: 500 if the upload cannot be saved, 400 if the audio
            cannot be decoded or preprocessed.
    """
    import asyncio  # only needed here, to move blocking work off the event loop

    ext = os.path.splitext(file.filename)[1] if file.filename else ""
    ext = ext or ".wav"

    tmp_path = None
    try:
        # 1. Persist the upload; the decoder needs a real file path.
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
                # Record the path before writing so a failed write is still cleaned up.
                tmp_path = tmp_file.name
                tmp_file.write(await file.read())
        except Exception as e:
            raise HTTPException(status_code=500, detail="Gagal menyimpan file audio sementara.") from e

        # 2. Decode and chunk the audio. This is CPU-bound, blocking work, so it
        #    runs in a worker thread instead of stalling the event loop.
        try:
            waveform_batch, mfcc_batch = await asyncio.to_thread(
                load_and_preprocess_audio_chunks, tmp_path, target_seconds=2.0
            )
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Gagal memproses file audio. Detail: {str(e)}") from e
    finally:
        # Single cleanup point for every path, including a failed write.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

    # 3. Score all chunks in one batch, also in a worker thread. no_grad is
    #    entered inside the thread because grad mode is thread-local.
    def _chunk_probabilities():
        with torch.no_grad():
            return torch.softmax(model(waveform_batch, mfcc_batch), dim=1)  # [num_chunks, 2]

    probs = await asyncio.to_thread(_chunk_probabilities)

    # 4. Average the per-chunk probabilities into one score for the file.
    avg_probs = probs.mean(dim=0).tolist()
    pred_idx = int(np.argmax(avg_probs))
    return {
        "filename": file.filename,
        "total_chunks_processed": waveform_batch.shape[0],  # how many chunks were scored
        "prediction": avg_probs,
        "label": LABELS[pred_idx],
        "confidence": avg_probs[pred_idx],
        "details": {
            "real_probability": avg_probs[0],
            "fake_probability": avg_probs[1],
        },
    }