# Hugging Face Space — audio spoof-detection FastAPI service (scrape artifact header removed)
| import json | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, UploadFile, HTTPException, File | |
| import nest_asyncio | |
| import uvicorn | |
| from model_utils import ( | |
| Model, | |
| ) # Assurez-vous que la classe Model est correctement importée | |
| import os | |
| import soundfile as sf | |
| import io | |
| import tempfile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from typing import List, Union | |
| from calculate_modules import compute_eer # Importer la fonction compute_eer | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) # | |
# FastAPI application setup.
app = FastAPI()

# Serve static frontend assets (JS/CSS/images) from the local "Web" directory.
app.mount("/static", StaticFiles(directory="Web"), name="static")


@app.get("/")  # NOTE(review): route decorator was missing — home() was defined but never registered; "/" assumed, confirm
def home():
    """Serve the single-page frontend."""
    return FileResponse("Web/index.html")


# Configure CORS: allow any origin/method/header (tighten for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allows every origin (can be restricted)
    allow_credentials=True,
    allow_methods=["*"],  # allows every HTTP method
    allow_headers=["*"],  # allows every header
)
# Load the model configuration.
def load_config(config_path):
    """Load the model configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        dict: The parsed configuration.

    Raises:
        HTTPException: 500 error when the file cannot be read or parsed.
    """
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        # Consistency fix: use the module logger (defined at the top of the
        # file and used by the endpoints) instead of print().
        logger.exception("Erreur lors du chargement de la configuration : %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du chargement de la configuration: {e}",
        )
# Load the model ensemble.
def load_model(checkpoint_path, d_args):
    """Build one ``Model`` per checkpoint and load its weights.

    Args:
        checkpoint_path: List of checkpoint file paths (one per ensemble member).
        d_args: Hyper-parameter dict forwarded to the ``Model`` constructor.

    Returns:
        list: Models in ``eval()`` mode, in the same order as ``checkpoint_path``.

    Raises:
        Exception: Re-raises any failure from ``torch.load`` / ``load_state_dict``.
    """
    models_list = []
    # enumerate() instead of range(len(...)): idiomatic, avoids double indexing.
    for i, path in enumerate(checkpoint_path):
        model = Model(d_args)
        try:
            # Always map onto CPU; inference below never moves tensors to GPU.
            checkpoint = torch.load(path, map_location=torch.device("cpu"))
            model.load_state_dict(checkpoint)
            logger.info("Model_%d loaded successfully.", i)
        except Exception:
            # logger.exception records the traceback; re-raise so startup fails loudly.
            logger.exception("Error loading model_%d", i)
            raise
        model.eval()  # freeze dropout/batch-norm statistics for inference
        models_list.append(model)
    return models_list
# Preprocess the audio.
def preprocess_audio(audio_path, sample_rate=16000):
    """Load an audio file and normalize it to mono at ``sample_rate`` Hz.

    Args:
        audio_path: Path of the audio file to load.
        sample_rate: Target sampling rate in Hz (default 16 kHz).

    Returns:
        torch.Tensor: Waveform of shape (1, num_samples), mono.

    Raises:
        HTTPException: 500 error when loading or resampling fails.
    """
    try:
        logger.info("Chargement de l'audio: %s", audio_path)
        waveform, sr = torchaudio.load(audio_path)
        logger.info("Audio chargé: %s, Taux d'échantillonnage: %s", audio_path, sr)
        # Resample only when the file's rate differs from the model's rate.
        if sr != sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=sample_rate
            )
            waveform = resample_transform(waveform)
        # Down-mix stereo/multichannel to mono by averaging channels.
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
    except Exception as e:
        # Consistency fix: module logger instead of print(); traceback included.
        logger.exception("Erreur dans le prétraitement audio : %s", e)
        raise HTTPException(
            status_code=500, detail=f"Erreur dans le prétraitement de l'audio: {e}"
        )
def infer(model_list, waveform, freq_aug=False):
    """Run the model ensemble on one waveform and return (label, score).

    Args:
        model_list: Models called as ``model(waveform, Freq_aug=freq_aug)``,
            each returning ``(last_hidden, logits)``.
        waveform: Input tensor; assumed shape (1, num_samples) — single item
            batch (the code indexes ``probabilities_sum[0]``).
        freq_aug: Forwarded as the model's ``Freq_aug`` flag (off at inference).

    Returns:
        tuple: ``(predicted_label, max_confidence)`` where ``predicted_label``
        is the argmax class of the averaged probabilities and
        ``max_confidence`` is ``1 - max(prob)``.

    Raises:
        ValueError: If a model returns ``None`` logits.
    """
    try:
        with torch.no_grad():
            probabilities_sum = None
            for model in model_list:
                last_hidden, output = model(waveform, Freq_aug=freq_aug)
                print("Sortie du modèle:", output)
                if output is None:
                    raise ValueError("La sortie du modèle est nulle.")
                probabilities = F.softmax(output, dim=1)
                # BUG FIX: the original added np.array(probabilities[0].tolist())
                # (shape (2,)) to a torch tensor of shape (1, 2); it only worked
                # by accidental broadcasting. Accumulate pure torch tensors.
                if probabilities_sum is None:
                    probabilities_sum = probabilities
                else:
                    probabilities_sum = probabilities_sum + probabilities
            # Average the softmax outputs over the ensemble.
            probabilities_sum = probabilities_sum / len(model_list)
            predicted_label = torch.argmax(probabilities_sum, dim=1).item()
            # Per-class probabilities of the single batch item.
            confidence = probabilities_sum[0].tolist()
            # NOTE(review): this is 1 minus the TOP probability (mass NOT on the
            # predicted class), not the top probability itself. Kept as-is —
            # the EER computation downstream depends on this scoring convention.
            max_confidence = 1 - max(confidence)
            return (predicted_label, max_confidence)
    except Exception as e:
        print(f"Erreur pendant l'inférence : {e}")
        raise
# Module-level bootstrap: load config and the 4-model ensemble at import time.
config_path = "./AASIST_ASVspoof5_Exp4_CL.conf"  # replace with the real path of your config file
config = load_config(config_path)
d_args = config["model_config"]  # hyper-parameters forwarded to Model(d_args)
checkpoint_path = [f"./S{i+1}_best.pth" for i in range(4)]  # replace with your checkpoints (one per ensemble member)
model_list = load_model(checkpoint_path, d_args)
@app.post("/predict")  # NOTE(review): route decorator was missing; "/predict" assumed — confirm against the frontend
async def predict(files: List[UploadFile] = File(...)):
    """
    Endpoint to handle batch inference for multiple audio files.
    Accepts a list of audio files and returns inference results for each file.

    Returns:
        list: One result dict per file ({"filename", "label", "confidence",
        "status", "volume"} on success, {"filename", "error", "status"} on
        failure), plus a trailing {"EER": "..%"} entry when both bonafide and
        spoof scores were collected.
    """
    responses = []
    bonafide_scores = []
    spoof_scores = []
    for file in files:
        temp_audio_path = None
        try:
            logger.info(f"Processing file: {file.filename}")
            # Validate file format
            if not file.filename.endswith((".wav", ".flac")):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file format. Only .wav and .flac files are allowed.",
                )
            # Save the uploaded file temporarily
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(await file.read())
            # Preprocess the audio file
            waveform = preprocess_audio(temp_audio_path)
            # Energy-based volume heuristic: RMS for overall level, peak for clipping headroom.
            rms = torch.sqrt(torch.mean(waveform**2)).item()
            max_eng = torch.sqrt(torch.max(waveform**2)).item()
            volume = "normal"
            if rms < 0.01:
                volume = "silent"
            elif rms < 0.05 and max_eng < 0.3:
                volume = "augmented"
                waveform = waveform * (1 / (3 * rms))  # Apply gain boost
            logger.info(
                f"Volume state for {file.filename}: rms={rms},max_eng={max_eng} > {volume}"
            )
            # Perform inference
            label, confidence = infer(model_list, waveform)
            # Store scores for EER calculation
            if label == 0:  # Bonafide
                bonafide_scores.append(confidence)
            else:  # Spoof
                spoof_scores.append(confidence)
            # Prepare the response
            responses.append(
                {
                    "filename": file.filename,
                    "label": "Genuine" if label == 0 else "Spoof",
                    "confidence": confidence,
                    "status": "success",
                    "volume": volume,
                }
            )
        except Exception as e:
            responses.append(
                {"filename": file.filename, "error": str(e), "status": "failed"}
            )
        finally:
            # BUG FIX: the temp file was only unlinked on the success path,
            # leaking files on any failure. Always clean it up.
            if temp_audio_path is not None and os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)
    # Log collected scores for debugging
    logger.info(f"Bonafide scores: {bonafide_scores}")
    logger.info(f"Spoof scores: {spoof_scores}")
    # Calculate EER if we have scores for both bonafide and spoof
    if bonafide_scores and spoof_scores:
        eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores))
        eer_percentage = eer * 100
        responses.append({"EER": f"{eer_percentage:.2f}%"})
        logger.info(f"Calculated EER: {eer_percentage:.2f}%")
    else:
        logger.info("Not enough data to calculate EER.")
    return responses
# Run the FastAPI server (nest_asyncio lets uvicorn.run() start inside an
# already-running event loop, e.g. in Colab / Jupyter).
nest_asyncio.apply()
# NOTE: the redundant second `import uvicorn` was removed — it is already
# imported at the top of the file.
uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,  # default Hugging Face Spaces port
    proxy_headers=True,  # trust X-Forwarded-* headers from the Space's proxy
    forwarded_allow_ips="*",
)