import json
import logging
import os
import tempfile
from typing import List

import nest_asyncio
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from calculate_modules import compute_eer
from model_utils import Model  # make sure the Model class is importable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI()

# Serve static frontend files
app.mount("/static", StaticFiles(directory="Web"), name="static")


@app.get("/")
def home():
    return FileResponse("Web/index.html")


# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins (can be restricted)
    allow_credentials=True,
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all headers
)


# Load the model configuration
def load_config(config_path):
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error loading configuration: {e}",
        )


# Load the ensemble of models
def load_model(checkpoint_path, d_args):
    models_list = []
    for i, path in enumerate(checkpoint_path):
        model = Model(d_args)
        try:
            # Load the checkpoint weights on CPU
            checkpoint = torch.load(path, map_location=torch.device("cpu"))
            model.load_state_dict(checkpoint)
            logger.info(f"Model_{i} loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading model_{i}: {e}")
            raise
        model.eval()
        models_list.append(model)
    return models_list


# Preprocess the audio: resample to 16 kHz and downmix to mono
def preprocess_audio(audio_path, sample_rate=16000):
    try:
        logger.info(f"Loading audio: {audio_path}")
        waveform, sr = torchaudio.load(audio_path)
        logger.info(f"Audio loaded: {audio_path}, sample rate: {sr}")
        if sr != sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=sample_rate
            )
            waveform = resample_transform(waveform)
        if waveform.size(0) > 1:
            # Convert to mono if stereo
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
    except Exception as e:
        logger.error(f"Error during audio preprocessing: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error during audio preprocessing: {e}"
        )


def infer(model_list, waveform, freq_aug=False):
    try:
        with torch.no_grad():
            probabilities_sum = None
            for model in model_list:
                _, output = model(waveform, Freq_aug=freq_aug)
                logger.info(f"Model output: {output}")
                if output is None:
                    raise ValueError("Model output is None.")
                probabilities = F.softmax(output, dim=1)
                # Accumulate the softmax distributions as torch tensors
                if probabilities_sum is None:
                    probabilities_sum = probabilities
                else:
                    probabilities_sum += probabilities
            # Average the probabilities over the ensemble
            probabilities_sum = probabilities_sum / len(model_list)
            predicted_label = torch.argmax(probabilities_sum, dim=1).item()
            confidence = probabilities_sum[0].tolist()  # probabilities for all classes
            # Score returned alongside the label: 1 minus the top-class probability
            max_confidence = 1 - max(confidence)
            return predicted_label, max_confidence
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise
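
# --- Ensemble averaging, illustrated ---
# infer() averages the per-model softmax distributions and then takes the
# argmax of the mean. A toy example with two models and two classes
# (the numbers below are made up, purely for illustration):
#
#   model A softmax -> [[0.90, 0.10]]
#   model B softmax -> [[0.40, 0.60]]
#   mean            -> [[0.65, 0.35]]  -> argmax = 0 ("Genuine")
#
# Averaging probabilities rather than raw logits keeps each model's vote on
# the same [0, 1] scale, so no single model dominates through large logits.
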
"./AASIST_ASVspoof5_Exp4_CL.conf" # Remplacez par le chemin réel de votre fichier de config config = load_config(config_path) d_args = config["model_config"] checkpoint_path = [f"./S{i+1}_best.pth" for i in range(4)] # Remplacez par votre checkpoint model_list = load_model(checkpoint_path, d_args) @app.post("/predict/") async def predict(files: List[UploadFile] = File(...)): """ Endpoint to handle batch inference for multiple audio files. Accepts a list of audio files and returns inference results for each file. """ responses = [] bonafide_scores = [] spoof_scores = [] for file in files: try: logger.info(f"Processing file: {file.filename}") # Validate file format if not file.filename.endswith((".wav", ".flac")): raise HTTPException( status_code=400, detail="Invalid file format. Only .wav and .flac files are allowed.", ) # Save the uploaded file temporarily with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio: temp_audio_path = temp_audio.name temp_audio.write(await file.read()) # Preprocess the audio file waveform = preprocess_audio(temp_audio_path) rms = torch.sqrt(torch.mean(waveform**2)).item() max_eng = torch.sqrt(torch.max(waveform**2)).item() volume = "normal" if rms < 0.01: volume = "silent" elif rms < 0.05 and max_eng<0.3: volume = "augmented" waveform = waveform * (1/(3*rms)) # Apply gain boost logger.info(f"Volume state for {file.filename}: rms={rms},max_eng={max_eng} > {volume}") # Perform inference label, confidence = infer(model_list, waveform) # Store scores for EER calculation if label == 0: # Bonafide bonafide_scores.append(confidence) else: # Spoof spoof_scores.append(confidence) # Prepare the response response = { "filename": file.filename, "label": "Genuine" if label == 0 else "Spoof", "confidence": confidence, "status": "success", "volume":volume } responses.append(response) # Clean up the temporary file os.unlink(temp_audio_path) except Exception as e: responses.append( {"filename": file.filename, "error": str(e), "status": "failed"} ) # Log collected scores for debugging logger.info(f"Bonafide scores: {bonafide_scores}") logger.info(f"Spoof scores: {spoof_scores}") # Calculate EER if we have scores for both bonafide and spoof if bonafide_scores and spoof_scores: eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores)) eer_percentage = eer * 100 responses.append({"EER": f"{eer_percentage:.2f}%"}) logger.info(f"Calculated EER: {eer_percentage:.2f}%") else: logger.info("Not enough data to calculate EER.") return responses # Exécuter le serveur FastAPI dans Colab nest_asyncio.apply() import uvicorn uvicorn.run( app, host="0.0.0.0", port=7860, proxy_headers=True, forwarded_allow_ips="*" )