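"""FastAPI service for audio anti-spoofing detection.

Loads an ensemble of AASIST model checkpoints, averages their softmax
outputs at inference time, and exposes a /predict/ endpoint for batch
scoring of .wav/.flac uploads, reporting an aggregate EER when both
bonafide and spoof predictions are present.
"""
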
import json
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, HTTPException, File
import nest_asyncio
import uvicorn
from model_utils import Model  # Make sure the Model class is correctly imported
import os
import tempfile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from typing import List
from calculate_modules import compute_eer  # EER computation helper
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI()

# Serve static frontend files
app.mount("/static", StaticFiles(directory="Web"), name="static")


@app.get("/")
def home():
    return FileResponse("Web/index.html")


# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (can be restricted)
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)


# Load the model configuration
def load_config(config_path):
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading the configuration: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error loading the configuration: {e}",
        )


# Load the ensemble of models
def load_model(checkpoint_path, d_args):
    models_list = []
    for i, path in enumerate(checkpoint_path):
        model = Model(d_args)
        try:
            # Load checkpoint
            checkpoint = torch.load(path, map_location=torch.device("cpu"))
            model.load_state_dict(checkpoint)
            print(f"Model_{i} loaded successfully.")
        except Exception as e:
            print(f"Error loading model_{i}: {e}")
            raise
        model.eval()
        models_list.append(model)
    return models_list


# Preprocess the audio
def preprocess_audio(audio_path, sample_rate=16000):
    try:
        print(f"Loading audio: {audio_path}")
        waveform, sr = torchaudio.load(audio_path)
        print(f"Audio loaded: {audio_path}, sample rate: {sr}")
        if sr != sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=sample_rate
            )
            waveform = resample_transform(waveform)
        if waveform.size(0) > 1:
            waveform = torch.mean(
                waveform, dim=0, keepdim=True
            )  # Downmix stereo to mono
        return waveform
    except Exception as e:
        print(f"Error in audio preprocessing: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error in audio preprocessing: {e}"
        )


def infer(model_list, waveform, freq_aug=False):
    try:
        with torch.no_grad():
            probabilities_sum = None
            for model in model_list:
                last_hidden, output = model(waveform, Freq_aug=freq_aug)
                print("Model output:", output)
                if output is None:
                    raise ValueError("Model output is None.")
                probabilities = F.softmax(output, dim=1)
                # Accumulate the softmax outputs as tensors so shapes and
                # dtypes stay consistent across the ensemble
                if probabilities_sum is None:
                    probabilities_sum = probabilities
                else:
                    probabilities_sum += probabilities

            # Average the ensemble's class probabilities
            probabilities_sum = probabilities_sum / len(model_list)
            predicted_label = torch.argmax(probabilities_sum, dim=1).item()
            confidence = probabilities_sum[0].tolist()  # Per-class probabilities
            # Score returned downstream: one minus the top class probability,
            # i.e. low when the ensemble is confident, high when it is not
            max_confidence = 1 - max(confidence)
            return predicted_label, max_confidence
    except Exception as e:
        print(f"Error during inference: {e}")
        raise


# Load the configuration and the ensemble checkpoints
config_path = "./AASIST_ASVspoof5_Exp4_CL.conf"  # Replace with the actual path to your config file
config = load_config(config_path)
d_args = config["model_config"]
checkpoint_path = [f"./S{i+1}_best.pth" for i in range(4)]  # Replace with your checkpoints
model_list = load_model(checkpoint_path, d_args)
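# NOTE: the config and all checkpoints are loaded at import time, so the
# service fails fast at startup if a file is missing instead of failing on
# the first request.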


@app.post("/predict/")
async def predict(files: List[UploadFile] = File(...)):
    """
    Endpoint to handle batch inference for multiple audio files.
    Accepts a list of audio files and returns inference results for each file.
    """
    responses = []
    bonafide_scores = []
    spoof_scores = []

    for file in files:
        try:
            logger.info(f"Processing file: {file.filename}")

            # Validate file format
            if not file.filename.endswith((".wav", ".flac")):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file format. Only .wav and .flac files are allowed.",
                )

            # Save the uploaded file temporarily, keeping the original
            # extension so torchaudio can pick the right decoder
            suffix = os.path.splitext(file.filename)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(await file.read())

            # Preprocess the audio file
            waveform = preprocess_audio(temp_audio_path)

            # Loudness heuristic: RMS energy plus peak amplitude
            rms = torch.sqrt(torch.mean(waveform**2)).item()
            max_eng = torch.sqrt(torch.max(waveform**2)).item()
            volume = "normal"

            if rms < 0.01:
                volume = "silent"
            elif rms < 0.05 and max_eng < 0.3:
                # Quiet but not silent: boost the gain so the RMS lands
                # around 1/3
                volume = "augmented"
                waveform = waveform * (1 / (3 * rms))

            logger.info(
                f"Volume state for {file.filename}: "
                f"rms={rms}, max_eng={max_eng} -> {volume}"
            )

            # Perform inference
            label, confidence = infer(model_list, waveform)

            # Store scores for EER calculation
            if label == 0:  # Bonafide
                bonafide_scores.append(confidence)
            else:  # Spoof
                spoof_scores.append(confidence)

            # Prepare the response
            response = {
                "filename": file.filename,
                "label": "Genuine" if label == 0 else "Spoof",
                "confidence": confidence,
                "status": "success",
                "volume":volume
            }
            responses.append(response)

            # Clean up the temporary file
            os.unlink(temp_audio_path)

        except Exception as e:
            responses.append(
                {"filename": file.filename, "error": str(e), "status": "failed"}
            )

    # Log collected scores for debugging
    logger.info(f"Bonafide scores: {bonafide_scores}")
    logger.info(f"Spoof scores: {spoof_scores}")

    # Calculate EER if we have scores for both bonafide and spoof
    if bonafide_scores and spoof_scores:
        eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores))
        eer_percentage = eer * 100
        responses.append({"EER": f"{eer_percentage:.2f}%"})
        logger.info(f"Calculated EER: {eer_percentage:.2f}%")
    else:
        logger.info("Not enough data to calculate EER.")

    return responses
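

# Example request (hypothetical file names; the server listens on port 7860):
#   curl -X POST "http://localhost:7860/predict/" \
#        -F "files=@sample1.wav" \
#        -F "files=@sample2.flac"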


# Run the FastAPI server (nest_asyncio lets uvicorn run inside an
# already-running event loop, e.g. in Colab)
nest_asyncio.apply()

uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,
    proxy_headers=True,
    forwarded_allow_ips="*",
)