# NOTE: removed non-Python extraction artifacts (file-size header, commit
# hashes, and line-number gutter) that would otherwise break parsing.
import json
import torch
import torchaudio
import numpy as np
from torch import nn
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, HTTPException, File
import nest_asyncio
import uvicorn
from model_utils import (
Model,
) # Assurez-vous que la classe Model est correctement importée
import os
import soundfile as sf
import io
import tempfile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from typing import List, Union
from calculate_modules import compute_eer # Importer la fonction compute_eer
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) #
# FastAPI application instance.
app = FastAPI()
# Serve static frontend files from the local "Web" directory under /static.
app.mount("/static", StaticFiles(directory="Web"), name="static")
@app.get("/")
def home():
    """Serve the single-page frontend at the site root."""
    index_page = "Web/index.html"
    return FileResponse(index_page)
# Configure CORS so the browser frontend can call this API.
# NOTE(review): Starlette's CORSMiddleware does not emit credentialed
# wildcard-origin headers (`allow_origins=["*"]` + `allow_credentials=True`
# are mutually exclusive per the CORS spec) — confirm credentials are
# actually needed, otherwise drop `allow_credentials`.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow every origin (should be restricted in production)
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all request headers
)
# Load the model configuration.
def load_config(config_path):
    """Load the model configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed configuration (typically a dict).

    Raises:
        HTTPException: 500 when the file cannot be read or parsed.
    """
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    # Narrowed from a bare `except Exception`: only I/O and JSON-decoding
    # errors are expected here; anything else should surface unchanged.
    except (OSError, json.JSONDecodeError) as e:
        print(f"Erreur lors du chargement de la configuration : {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du chargement de la configuration: {e}",
        ) from e
# Load the ensemble of models.
def load_model(checkpoint_path, d_args):
    """Instantiate one ``Model`` per checkpoint and load its weights.

    Args:
        checkpoint_path: List of checkpoint file paths, one per ensemble member.
        d_args: Model-architecture config dict passed to ``Model``.

    Returns:
        List of models in eval mode, in the same order as ``checkpoint_path``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint (re-raised after logging).
    """
    models_list = []
    # `enumerate` instead of `range(len(...))`: i is only used for log messages.
    for i, ckpt in enumerate(checkpoint_path):
        model = Model(d_args)
        try:
            # map_location=CPU so loading also works on machines without a GPU.
            checkpoint = torch.load(ckpt, map_location=torch.device("cpu"))
            model.load_state_dict(checkpoint)
            print(f"Model_{i} loaded successfully.")
        except Exception as e:
            print(f"Error loading model_{i}: {e}")
            raise
        model.eval()  # inference mode: freeze dropout / batch-norm statistics
        models_list.append(model)
    return models_list
# Audio preprocessing.
def preprocess_audio(audio_path, sample_rate=16000):
    """Load an audio file, resample it to ``sample_rate``, and down-mix to mono.

    Args:
        audio_path: Path of the audio file on disk.
        sample_rate: Target sampling rate in Hz (default 16 kHz).

    Returns:
        A mono waveform tensor of shape (1, num_samples).

    Raises:
        HTTPException: 500 on any loading or processing failure.
    """
    try:
        print(f"Chargement de l'audio: {audio_path}")
        signal, native_sr = torchaudio.load(audio_path)
        print(f"Audio chargé: {audio_path}, Taux d'échantillonnage: {native_sr}")
        # Resample first (on all channels), exactly as before.
        if native_sr != sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=native_sr, new_freq=sample_rate
            )
            signal = resampler(signal)
        # Then average channels down to mono if the input was stereo.
        if signal.size(0) > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal
    except Exception as exc:
        print(f"Erreur dans le prétraitement audio : {exc}")
        raise HTTPException(
            status_code=500, detail=f"Erreur dans le prétraitement de l'audio: {exc}"
        )
def infer(model_list, waveform, freq_aug=False):
    """Run the model ensemble on a waveform and return (label, score).

    Args:
        model_list: Non-empty iterable of models; each call must return a
            ``(last_hidden, logits)`` pair.
        waveform: Input tensor fed to every model — assumed to be the
            (1, num_samples) output of ``preprocess_audio`` (TODO confirm).
        freq_aug: Forwarded to each model as its ``Freq_aug`` flag.

    Returns:
        Tuple ``(predicted_label, score)``: the argmax class of the averaged
        softmax probabilities, and ``1 - max(probabilities)``.
        NOTE(review): the original comment called the second value "the
        highest probability", but the code returns its complement; behavior
        is preserved as-is since the EER computation downstream consumes it.

    Raises:
        ValueError: if ``model_list`` is empty or a model returns ``None``.
    """
    try:
        if not model_list:
            raise ValueError("model_list must contain at least one model.")
        with torch.no_grad():
            probabilities_sum = None
            for model in model_list:
                last_hidden, output = model(waveform, Freq_aug=freq_aug)
                print("Sortie du modèle:", output)
                if output is None:
                    raise ValueError("La sortie du modèle est nulle.")
                probabilities = F.softmax(output, dim=1)
                # BUG FIX: the original accumulated
                # `np.array(probabilities[0].tolist())` here, mixing a NumPy
                # float64 (2,) array into a torch (1, 2) tensor via implicit
                # broadcasting/conversion. Accumulate tensors directly.
                if probabilities_sum is None:
                    probabilities_sum = probabilities
                else:
                    probabilities_sum = probabilities_sum + probabilities
            probabilities_sum = probabilities_sum / len(model_list)
            predicted_label = torch.argmax(probabilities_sum, dim=1).item()
            # Per-class probabilities for the single batch element.
            confidence = probabilities_sum[0].tolist()
            max_confidence = 1 - max(confidence)
            return (predicted_label, max_confidence)
    except Exception as e:
        print(f"Erreur pendant l'inférence : {e}")
        raise
# Load the ensemble config and weights once at import time (module-level
# side effect: the server cannot start if these files are missing).
config_path = "./AASIST_ASVspoof5_Exp4_CL.conf"  # Replace with your actual config path
config = load_config(config_path)
d_args = config["model_config"]
# Four ensemble members: ./S1_best.pth ... ./S4_best.pth
checkpoint_path = [f"./S{i+1}_best.pth" for i in range(4)]  # Replace with your checkpoints
model_list = load_model(checkpoint_path, d_args)
@app.post("/predict/")
async def predict(files: List[UploadFile] = File(...)):
    """
    Endpoint to handle batch inference for multiple audio files.

    Accepts a list of audio files (.wav / .flac), runs the model ensemble on
    each, and returns one result dict per file. When both bonafide and spoof
    scores were collected, a final ``{"EER": ...}`` entry is appended.
    """
    responses = []
    bonafide_scores = []
    spoof_scores = []
    for file in files:
        # Track the temp path outside the try so cleanup in `finally` is
        # safe even when saving/preprocessing fails part-way through.
        temp_audio_path = None
        try:
            logger.info(f"Processing file: {file.filename}")
            # Validate file format (the raised HTTPException is converted
            # into a per-file "failed" entry by the except clause below,
            # matching the original control flow).
            if not file.filename or not file.filename.endswith((".wav", ".flac")):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file format. Only .wav and .flac files are allowed.",
                )
            # Persist the upload to a temporary file for torchaudio.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(await file.read())
            waveform = preprocess_audio(temp_audio_path)
            # Simple energy-based volume classification / gain boost.
            rms = torch.sqrt(torch.mean(waveform**2)).item()
            max_eng = torch.sqrt(torch.max(waveform**2)).item()
            volume = "normal"
            if rms < 0.01:
                volume = "silent"
            elif rms < 0.05 and max_eng < 0.3:
                volume = "augmented"
                waveform = waveform * (1 / (3 * rms))  # gain boost for quiet input
            logger.info(f"Volume state for {file.filename}: rms={rms},max_eng={max_eng} > {volume}")
            # Perform inference.
            label, confidence = infer(model_list, waveform)
            # Collect scores for the EER calculation below.
            if label == 0:  # Bonafide
                bonafide_scores.append(confidence)
            else:  # Spoof
                spoof_scores.append(confidence)
            responses.append(
                {
                    "filename": file.filename,
                    "label": "Genuine" if label == 0 else "Spoof",
                    "confidence": confidence,
                    "status": "success",
                    "volume": volume,
                }
            )
        except Exception as e:
            responses.append(
                {"filename": file.filename, "error": str(e), "status": "failed"}
            )
        finally:
            # BUG FIX: the original only unlinked on success, leaking the
            # temp file whenever preprocessing or inference raised.
            if temp_audio_path and os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)
    # Log collected scores for debugging.
    logger.info(f"Bonafide scores: {bonafide_scores}")
    logger.info(f"Spoof scores: {spoof_scores}")
    # Calculate EER only when both classes are represented.
    if bonafide_scores and spoof_scores:
        eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores))
        eer_percentage = eer * 100
        responses.append({"EER": f"{eer_percentage:.2f}%"})
        logger.info(f"Calculated EER: {eer_percentage:.2f}%")
    else:
        logger.info("Not enough data to calculate EER.")
    return responses
# Run the FastAPI server; nest_asyncio lets uvicorn run inside an already
# running event loop (e.g. Colab / Jupyter).
nest_asyncio.apply()
import uvicorn  # NOTE(review): redundant — uvicorn is already imported at the top of the file
uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,
    proxy_headers=True,  # trust X-Forwarded-* headers from the reverse proxy
    forwarded_allow_ips="*"
)