File size: 6,431 Bytes
13f9583
 
 
 
 
 
4b5b31b
0df8ded
13f9583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c958a0
 
 
 
 
4b5b31b
13f9583
 
 
 
 
 
 
6c958a0
 
 
 
 
 
 
 
 
 
13f9583
 
6c958a0
13f9583
6c958a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f9583
6c958a0
 
13f9583
 
 
 
 
 
6c958a0
13f9583
 
 
4b5b31b
 
 
 
13f9583
 
 
 
 
6c958a0
13f9583
6c958a0
13f9583
6c958a0
 
13f9583
 
 
6c958a0
13f9583
6c958a0
 
 
 
 
13f9583
 
6c958a0
13f9583
6c958a0
13f9583
 
 
6c958a0
 
13f9583
 
 
6c958a0
 
13f9583
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import tempfile
from fastapi import FastAPI, UploadFile, File, HTTPException
import torch
import torch.nn as nn
import torchaudio
import torchcodec  # Added as requested to handle audio parsing
import soundfile  # noqa: F401  — diperlukan sebagai backend torchaudio
import numpy as np

app = FastAPI()

# =============== CONFIGURATION ==================
# Target sample rate (Hz); audio with a different rate is resampled to this.
SR = 16000
# Number of MFCC coefficients passed to the MFCC transform.
N_MFCC = 40
# Number of mel filterbanks passed to the MFCC transform.
N_MELS = 64
# Matches the label order in the dataset (0: real, 1: fake)
LABELS = ["real", "fake"]

# =============== MODEL ARCHITECTURE ==================
class HybridAudioCNN(nn.Module):
    """Two-branch CNN combining raw-waveform and MFCC features.

    A 1D convolutional branch processes the waveform, a 2D convolutional
    branch processes the MFCC map, and the flattened outputs of both are
    concatenated and passed to a small fully connected classifier.

    NOTE: both branches are flat ``nn.Sequential`` containers; the position of
    each layer determines its ``state_dict`` key, so the layer order must not
    change if existing checkpoints are to keep loading.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        def conv1d_unit(c_in, c_out):
            # Conv1d (kernel 5, stride 2) -> BatchNorm -> ReLU
            return [
                nn.Conv1d(c_in, c_out, 5, stride=2, padding=2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]

        def conv2d_unit(c_in, c_out):
            # Conv2d (kernel 3, same padding) -> BatchNorm -> ReLU
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        wave_pool_len = 32
        mfcc_pool_hw = (8, 8)
        channels = 128

        # Waveform branch: output is [B, 128, 32].
        self.waveform_branch = nn.Sequential(
            *conv1d_unit(1, 32),
            *conv1d_unit(32, 64),
            *conv1d_unit(64, channels),
            nn.AdaptiveAvgPool1d(wave_pool_len),
        )

        # MFCC branch: output is [B, 128, 8, 8].
        self.mfcc_branch = nn.Sequential(
            *conv2d_unit(1, 32),
            *conv2d_unit(32, 64),
            nn.MaxPool2d(2),
            *conv2d_unit(64, channels),
            nn.AdaptiveAvgPool2d(mfcc_pool_hw),
        )

        # Classifier input = flattened waveform features + flattened MFCC features.
        fused_dim = channels * (wave_pool_len + mfcc_pool_hw[0] * mfcc_pool_hw[1])
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    def forward(self, waveform, mfcc):
        """Return class logits for a batch of waveforms and their MFCC maps.

        A channel axis is inserted at dim 1 of each input before it enters its
        branch; the caller in this file passes ``waveform`` as [B, T] and
        ``mfcc`` as [B, n_mfcc, frames].
        """
        wave_feats = torch.flatten(self.waveform_branch(waveform[:, None]), 1)
        mfcc_feats = torch.flatten(self.mfcc_branch(mfcc[:, None]), 1)
        fused = torch.cat((wave_feats, mfcc_feats), dim=1)
        return self.classifier(fused)

# =============== MODEL INITIALIZATION ==================
model = HybridAudioCNN(num_classes=2)
# Make sure best_hybrid_cnn.pth (or last_model.pth) is in the same folder.
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted source.
try:
    model.load_state_dict(torch.load("best_hybrid_cnn.pth", map_location=torch.device('cpu')))
except Exception as e:
    # Best-effort start-up: the API still boots without the checkpoint, but
    # predictions then come from randomly initialised weights.
    print(f"Peringatan: Gagal memuat model. Pastikan file best_hybrid_cnn.pth tersedia. Error: {e}")
# Always switch to inference mode, even when loading failed; otherwise Dropout
# and BatchNorm would keep their training behaviour while serving requests.
model.eval()

# =============== AUDIO PREPROCESSING ==================
import torch.nn.functional as F

# =============== AUDIO PREPROCESSING (UPDATED FOR CHUNKING) ==================
def load_and_preprocess_audio_chunks(file_path, target_seconds=2.0):
    """Load an audio file and split it into fixed-length chunks with MFCCs.

    The audio is down-mixed to mono, resampled to ``SR`` when needed, and cut
    into consecutive windows of ``target_seconds``. The last window is
    zero-padded to full length. Each window's MFCC is standardised with that
    window's own mean and standard deviation.

    Args:
        file_path: Path of the audio file to read with ``torchaudio.load``.
        target_seconds: Length of each chunk in seconds; must yield at least
            one sample at ``SR``.

    Returns:
        ``(wave_batch, mfcc_batch)`` where ``wave_batch`` is
        ``[num_chunks, chunk_samples]`` and ``mfcc_batch`` is
        ``[num_chunks, N_MFCC, frames]``.

    Raises:
        ValueError: If ``target_seconds`` gives a non-positive chunk size or
            the file contains no audio samples.
    """
    chunk_samples = int(target_seconds * SR)
    if chunk_samples <= 0:
        raise ValueError(
            f"target_seconds must yield at least one sample, got {target_seconds!r}"
        )

    # 1. Load, down-mix to mono, resample
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SR)
        waveform = resampler(waveform)

    waveform = waveform.squeeze(0)  # 1D tensor [T]

    # Without this guard an empty file would only fail later, inside
    # torch.stack, with a message unrelated to the real cause.
    if waveform.shape[0] == 0:
        raise ValueError("Audio file contains no samples.")

    # 2. MFCC extractor shared by every chunk
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=SR, n_mfcc=N_MFCC,
        melkwargs={"n_fft": 512, "n_mels": N_MELS, "hop_length": 160, "f_min": 80, "f_max": 7600}
    )

    wave_chunks = []
    mfcc_chunks = []

    # 3. Walk the signal in chunk_samples-sized pieces (the last may be shorter)
    for chunk in torch.split(waveform, chunk_samples):
        shortfall = chunk_samples - chunk.shape[0]
        if shortfall > 0:
            # Zero-pad the trailing chunk so every chunk has the same length.
            chunk = F.pad(chunk, (0, shortfall))

        # Per-chunk MFCC, standardised with the chunk's own statistics.
        mfcc = mfcc_transform(chunk)
        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-6)

        wave_chunks.append(chunk)
        mfcc_chunks.append(mfcc)

    # 4. Stack the per-chunk tensors into batches
    wave_batch = torch.stack(wave_chunks)
    mfcc_batch = torch.stack(mfcc_chunks)

    return wave_batch, mfcc_batch


# =============== API ENDPOINTS ==================
@app.get("/")
async def root():
    # Static status payload confirming the service is reachable.
    status = {"message": "Audio Deepfake Detector (Hybrid CNN) API is running!"}
    return status

# =============== API ENDPOINTS (UPDATED FOR BATCH INFERENCE) ==================
@app.post("/predict-audio")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file as real or fake.

    The upload is written to a temporary file, split into 2-second chunks,
    and every chunk is scored by the model in one batch. The per-chunk
    softmax probabilities are averaged to give the final prediction.

    Raises:
        HTTPException: 500 if the upload cannot be saved to a temporary file;
            400 if the audio cannot be loaded or preprocessed.
    """
    tmp_path = None
    try:
        # 1. Persist the upload so torchaudio can read it from disk
        try:
            ext = os.path.splitext(file.filename)[1] if file.filename else ".wav"
            if not ext:
                ext = ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
                # Record the path before writing so a failed write can still
                # be cleaned up (the file is created with delete=False).
                tmp_path = tmp_file.name
                tmp_file.write(await file.read())
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail="Gagal menyimpan file audio sementara."
            ) from exc

        # 2. Preprocess the audio into batch tensors
        try:
            waveform_batch, mfcc_batch = load_and_preprocess_audio_chunks(tmp_path, target_seconds=2.0)
        except Exception as exc:
            raise HTTPException(
                status_code=400, detail=f"Gagal memproses file audio. Detail: {str(exc)}"
            ) from exc
    finally:
        # Single cleanup point for both the success and the failure paths.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

    # 3 & 4. Model inference on the whole batch of chunks
    with torch.no_grad():
        output = model(waveform_batch, mfcc_batch)
        probs = torch.softmax(output, dim=1)  # [num_chunks, 2]

        # Average the probabilities over all chunks
        avg_probs = torch.mean(probs, dim=0).numpy().tolist()

    # 5. Extract the result
    pred_idx = int(np.argmax(avg_probs))
    pred_label = LABELS[pred_idx]
    confidence = avg_probs[pred_idx]

    return {
        "filename": file.filename,
        "total_chunks_processed": waveform_batch.shape[0],  # number of chunks scored
        "prediction": avg_probs,
        "label": pred_label,
        "confidence": confidence,
        "details": {
            "real_probability": avg_probs[0],
            "fake_probability": avg_probs[1]
        }
    }