MeysamSh commited on
Commit
56053bc
·
1 Parent(s): 640285d

Add application file

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. app.py +195 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as the base image
FROM python:3.11-slim

# Set the working directory in the container
WORKDIR /app

# Copy the requirements file first so the dependency layer is cached
# when only application code changes
COPY requirements.txt .

# Install the Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code
COPY . .

# Expose the port the app runs on
EXPOSE 7860

# BUGFIX: the application module is app.py, not main.py — "main:app" would
# fail at container start with "Could not import module main".
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from fastapi import FastAPI, UploadFile, HTTPException, File
8
+ import nest_asyncio
9
+ import uvicorn
10
+ from model_utils import (
11
+ Model,
12
+ ) # Assurez-vous que la classe Model est correctement importée
13
+ import os
14
+ import soundfile as sf
15
+ import io
16
+ import tempfile
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from typing import List, Union
19
+ from calculate_modules import compute_eer # Importer la fonction compute_eer
20
+ import logging
21
+
22
# Configure application-wide logging and a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance (served by uvicorn; see bottom of file).
app = FastAPI()

# Configure CORS.
# NOTE(review): CORS is currently wide open ("*") — fine for a demo, but
# should be restricted to known origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins (can be restricted)
    allow_credentials=True,
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all headers
)
35
+
36
+
37
# Load the model configuration.
def load_config(config_path):
    """Load the model configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        HTTPException: 500 if the file cannot be read or parsed.
            NOTE(review): this function is also called at import time
            (module-level model load below), where an HTTPException is not
            handled by FastAPI and simply aborts startup.
    """
    try:
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        # Use the module logger (not print) so failures reach server logs.
        logger.error("Erreur lors du chargement de la configuration : %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du chargement de la configuration: {e}",
        ) from e
48
+
49
+
50
# Load the anti-spoofing model.
def load_model(checkpoint_path, d_args):
    """Build a Model and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a .pth state-dict checkpoint.
        d_args: Model architecture arguments (the "model_config" dict).

    Returns:
        The model in eval() mode.

    Raises:
        Whatever torch.load / load_state_dict raises on failure (re-raised
        after logging).
    """
    model = Model(d_args)
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (weights_only=True would be safer if the
        # checkpoint format supports it).
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint)
        logger.info("Model loaded successfully.")
    except Exception as e:
        # Log via the module logger (not print) for consistency.
        logger.error("Error loading model: %s", e)
        raise
    model.eval()
    return model
63
+
64
+
65
# Preprocess an audio file for model input.
def preprocess_audio(audio_path, sample_rate=16000):
    """Load an audio file, resample to ``sample_rate``, and downmix to mono.

    Args:
        audio_path: Path to an audio file readable by torchaudio.
        sample_rate: Target sample rate in Hz (default 16 kHz).

    Returns:
        A mono waveform tensor of shape (1, num_samples).

    Raises:
        HTTPException: 500 if loading/resampling fails.
    """
    try:
        logger.info("Chargement de l'audio: %s", audio_path)
        waveform, sr = torchaudio.load(audio_path)
        logger.info("Audio chargé: %s, taux d'échantillonnage: %s", audio_path, sr)
        if sr != sample_rate:
            # Resample to the rate the model was trained on.
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=sample_rate
            )
            waveform = resample_transform(waveform)
        if waveform.size(0) > 1:
            # Downmix multi-channel audio to mono by averaging channels.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
    except Exception as e:
        # Log via the module logger (not print) for consistency.
        logger.error("Erreur dans le prétraitement audio : %s", e)
        raise HTTPException(
            status_code=500, detail=f"Erreur dans le prétraitement de l'audio: {e}"
        ) from e
86
+
87
+
88
def infer(model, waveform, freq_aug=False):
    """Run the model on a preprocessed waveform.

    Returns:
        Tuple ``(label, score)`` where ``label`` is the argmax class index
        and ``score`` is ``1 - max(class probabilities)`` — i.e. an
        uncertainty measure (despite the original "max_confidence" name),
        consumed downstream for the EER computation.

    Raises:
        ValueError: if the model produced no output.
    """
    try:
        with torch.no_grad():
            _, output = model(waveform, Freq_aug=freq_aug)
            print("Sortie du modèle:", output)
            if output is None:
                raise ValueError("La sortie du modèle est nulle.")
            probs = F.softmax(output, dim=1)
            label = torch.argmax(probs, dim=1).item()
            # Per-class probabilities for the first (only) batch item.
            class_probs = probs[0].tolist()
            # One minus the top probability: an uncertainty score.
            score = 1 - max(class_probs)
            return label, score
    except Exception as e:
        print(f"Erreur pendant l'inférence : {e}")
        raise
108
+
109
+
110
# Load the deployed model once at import time so the first request is fast.
# NOTE(review): any failure here (missing config/checkpoint file) aborts
# application startup.
config_path = "./AASIST_ASVspoof5_Exp4_CL.conf"  # path to the model config file
config = load_config(config_path)
d_args = config["model_config"]  # architecture arguments passed to Model
checkpoint_path = (
    "./Ex4_CLspeaker_sampler_eer0.164.pth"  # trained checkpoint (state dict)
)
model = load_model(checkpoint_path, d_args)
118
+
119
+
120
@app.post("/predict/")
async def predict(files: List[UploadFile] = File(...)):
    """
    Endpoint to handle batch inference for multiple audio files.
    Accepts a list of audio files and returns inference results for each
    file; when both classes are observed, an aggregate EER entry is appended.
    """
    responses = []
    bonafide_scores = []
    spoof_scores = []

    for file in files:
        temp_audio_path = None
        try:
            logger.info(f"Processing file: {file.filename}")

            # Validate file format.
            # NOTE(review): this HTTPException is caught by the broad
            # except below, so invalid files yield a "failed" entry rather
            # than an HTTP 400 — kept as-is to preserve the batch contract.
            if not file.filename.endswith((".wav", ".flac")):
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file format. Only .wav and .flac files are allowed.",
                )

            # Save the uploaded file temporarily
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(await file.read())

            # Preprocess the audio file
            waveform = preprocess_audio(temp_audio_path)

            # Perform inference
            label, confidence = infer(model, waveform)

            # Store per-class scores for the aggregate EER calculation
            if label == 0:  # Bonafide
                bonafide_scores.append(confidence)
            else:  # Spoof
                spoof_scores.append(confidence)

            # Prepare the per-file response
            responses.append(
                {
                    "filename": file.filename,
                    "label": "Genuine" if label == 0 else "Spoof",
                    "confidence": confidence,
                    "status": "success",
                }
            )
        except Exception as e:
            # Report per-file failures instead of failing the whole batch.
            responses.append(
                {"filename": file.filename, "error": str(e), "status": "failed"}
            )
        finally:
            # BUGFIX: always remove the temp file — the original only
            # unlinked on success, leaking files whenever preprocessing or
            # inference raised.
            if temp_audio_path and os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)

    # Log collected scores for debugging
    logger.info(f"Bonafide scores: {bonafide_scores}")
    logger.info(f"Spoof scores: {spoof_scores}")

    # EER needs at least one score from each class.
    if bonafide_scores and spoof_scores:
        eer, _, _, _ = compute_eer(np.array(bonafide_scores), np.array(spoof_scores))
        eer_percentage = eer * 100
        responses.append({"EER": f"{eer_percentage:.2f}%"})
        logger.info(f"Calculated EER: {eer_percentage:.2f}%")
    else:
        logger.info("Not enough data to calculate EER.")

    return responses
189
+
190
+
191
# Start the FastAPI server when this file is executed as a script.
# nest_asyncio lets uvicorn.run work inside an already-running event loop
# (e.g. a notebook/Colab environment).
# BUGFIX: removed the duplicate "import uvicorn" (already imported at the
# top of the file) and guarded startup with __main__ so importing this
# module (e.g. via "uvicorn app:app" in the Dockerfile) does not start a
# second, blocking server at import time.
# NOTE(review): this script serves on port 8000 while the Dockerfile
# exposes 7860 — confirm which port is intended per deployment path.
if __name__ == "__main__":
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Python dependencies for the audio deepfake-detection API (see app.py).
# NOTE(review): torch and torchaudio versions must stay in lockstep —
# confirm torchaudio==2.0.0 is the build matching torch==2.0.0.
fastapi==0.68.0
uvicorn==0.15.0
torch==2.0.0
torchaudio==2.0.0
numpy<2
nest-asyncio==1.5.1
pydub==0.25.1
soundfile  # audio file I/O
python-multipart  # required by FastAPI for UploadFile/form parsing