Mrkomiljon commited on
Commit
288c4d9
·
verified ·
1 Parent(s): f786f1d

Upload 5 files

Browse files

some files updated

Files changed (5) hide show
  1. RawNet_model.onnx +3 -0
  2. app.py +114 -0
  3. best_model.pth +3 -0
  4. data_utils.py +94 -0
  5. inference_onnx.py +72 -0
RawNet_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64e9e09f132ecb8d4a4fc60ec29fab2a35e3b4cd8605e5489ba3a5d085d143e2
3
+ size 70911020
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from model import RawNet
7
+ from data_utils import pad # Import the pad function from data_utils.py
8
+ import yaml
9
+ import torch.nn.functional as F # For softmax
10
+ from fastapi.responses import JSONResponse
11
+ from tempfile import NamedTemporaryFile
12
+ import uvicorn
13
+ import webbrowser
14
+
15
# Initialize FastAPI app
app = FastAPI()

# Load the model
# NOTE(review): absolute, developer-machine-specific Windows paths with
# doubled escape sequences — confirm these resolve on the deploy target.
model_config_path = 'C:\\\\Users\\\\GOOD\\\\Desktop\\\\TEST-2024\\\\2021\\\\LA\\\\Baseline-RawNet2\\\\model_config_RawNet.yaml'
model_path = 'C:\\\\Users\\\\GOOD\\\\Desktop\\\\TEST-2024\\\\2021\\\\LA\\\\Baseline-RawNet2\\\\checkpoints\\\\best_model.pth'
# Prefer GPU when available; the model and every input tensor are moved here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model hyperparameters are read from the YAML config's 'model' section.
with open(model_config_path, 'r') as f:
    model_config = yaml.safe_load(f)

# Build RawNet, restore the trained weights, and switch to inference mode
# (eval() disables dropout / freezes batch-norm statistics).
model = RawNet(model_config['model'], device).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
29
+
30
+
31
def preprocess_audio_segment(segment, cut=64600):
    """Normalize one audio segment to exactly ``cut`` samples and wrap it
    as a batched float32 tensor of shape (1, cut)."""
    needs_padding = len(segment) < cut
    # Tile-pad short segments; truncate long ones.
    fixed = pad(segment, max_len=cut) if needs_padding else segment[:cut]
    batch = torch.tensor(fixed, dtype=torch.float32)
    return batch.unsqueeze(0)
40
+
41
+
42
def predict_with_sliding_window(waveform, model, device, window_size=64600, step_size=64600, sample_rate=16000):
    """
    Use a sliding window to predict if the audio is real or fake over the entire audio.

    Args:
        waveform: 1-D array of audio samples (expected 16 kHz mono).
        model: Torch module returning class logits of shape (1, 2).
        device: Torch device string ('cuda' or 'cpu').
        window_size: Samples per analysis window (default 64600 ≈ 4 s @ 16 kHz).
        step_size: Hop between window starts; equal to window_size means no overlap.
        sample_rate: Unused here; kept for interface compatibility.

    Returns:
        (majority_class, avg_probability): the majority-voted label over all
        windows and the mean confidence (in percent) of each window's winner.

    Raises:
        ValueError: If the waveform contains no samples.
    """
    if len(waveform) == 0:
        # Fix: without this guard an empty input reached max() over an empty
        # list below and crashed with a cryptic ValueError.
        raise ValueError("Empty waveform: no audio samples to classify")

    total_segments = []
    total_probabilities = []

    # Slide a (by default non-overlapping) window across the waveform; the
    # final short remainder is padded inside preprocess_audio_segment.
    for start in range(0, len(waveform), step_size):
        segment = waveform[start:start + window_size]
        audio_tensor = preprocess_audio_segment(segment).to(device)

        with torch.no_grad():
            output = model(audio_tensor)
            probabilities = F.softmax(output, dim=1)  # logits -> probabilities
            prediction = torch.argmax(probabilities, dim=1)

        # Label convention: class 1 == human, class 0 == TTS/AI-generated.
        predicted_class = "Human voice" if prediction.item() == 1 else "AI generated voice (TTS)"
        probability = probabilities[0, prediction.item()].item() * 100
        total_segments.append(predicted_class)
        total_probabilities.append(probability)

    # Majority vote across windows; average the winning-class confidences.
    majority_class = max(set(total_segments), key=total_segments.count)
    avg_probability = np.mean(total_probabilities)

    return majority_class, avg_probability
74
+
75
+
76
@app.post("/predict")
async def predict_audio(file: UploadFile = File(...)):
    """
    Endpoint to process audio and predict using the RawNet model.

    Accepts an uploaded audio file, resamples it to 16 kHz mono, runs the
    sliding-window classifier, and returns the majority label plus the
    average confidence. Any failure is reported as a 500 JSON error.
    """
    temp_filename = None
    try:
        # Persist the upload to disk so librosa can open it by path.
        with NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(await file.read())
            temp_filename = temp_file.name

        # librosa resamples to 16 kHz and downmixes to mono by default.
        waveform, _ = librosa.load(temp_filename, sr=16000)

        result, avg_probability = predict_with_sliding_window(waveform, model, device)

        return JSONResponse({
            "Your audio": result,
            "average_probability": f"{avg_probability:.2f}%"
        })
    except Exception as e:
        # Boundary handler: surface any failure as a JSON 500 response.
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Fix: the original deleted the temp file only on the success path,
        # leaking it whenever loading or inference raised.
        if temp_filename is not None and os.path.exists(temp_filename):
            os.remove(temp_filename)
103
+
104
+
105
@app.get("/")
async def root():
    """Landing endpoint: identifies the running API."""
    greeting = {"message": "RawNet Sliding Window Prediction API"}
    return greeting
108
+
109
# Automatically open docs or print URL when server starts
if __name__ == "__main__":
    # Interactive Swagger UI served by FastAPI at /docs.
    url = "http://127.0.0.1:8000/docs"
    print(f"API docs available at: {url}")
    webbrowser.open(url)  # Open in the default browser
    # Blocks here serving the app on localhost only (not externally reachable).
    uvicorn.run(app, host="127.0.0.1", port=8000)
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:940acc620962f2ce0e2b1f91c3c514bc9128240b5800612205aaead7b78c1c64
3
+ size 70532085
data_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch import Tensor
5
+ import librosa
6
+ from torch.utils.data import Dataset
7
+
8
+
9
# Pad (or trim) an audio waveform to a fixed length
def pad(x, max_len=64600):
    """Pad or trim a 1-D waveform to exactly ``max_len`` samples.

    Longer inputs are truncated; shorter ones are tiled (repeated) until
    they reach ``max_len``.

    Args:
        x: 1-D numpy array of audio samples.
        max_len: Target length in samples (default 64600 ≈ 4 s @ 16 kHz).

    Returns:
        Numpy array of length ``max_len``.

    Raises:
        ValueError: If ``x`` is empty (the original code crashed here with
            an opaque ZeroDivisionError).
    """
    x_len = x.shape[0]
    if x_len == 0:
        raise ValueError("Cannot pad an empty waveform")
    if x_len >= max_len:
        return x[:max_len]
    # Repeat the signal enough times, then cut to the exact target length.
    # (Equivalent to the original np.tile(x, (1, n))[:, :max_len][0] for 1-D x,
    # without the detour through a 2-D array.)
    num_repeats = (max_len // x_len) + 1
    return np.tile(x, num_repeats)[:max_len]
18
+
19
def genSpoof_list(dir_meta, is_train=False, is_eval=False):
    """Parse an ASVspoof protocol (metadata) file.

    Args:
        dir_meta: Path to the protocol file.
        is_train: Treat the file as a train protocol (5 space-separated
            columns per line: speaker, key, _, _, label).
        is_eval: Treat the file as an eval list (one utterance key per line).

    Returns:
        For train/dev: ``(d_meta, file_list)`` where ``d_meta`` maps each key
        to 1 ('bonafide') or 0 (anything else). For eval: ``file_list`` only.
    """
    with open(dir_meta, 'r') as f:
        l_meta = f.readlines()

    # Fix: the original duplicated identical parsing code in the is_train and
    # dev (else) branches. Branch priority is preserved: is_train wins over
    # is_eval when both flags are set.
    if is_train or not is_eval:
        d_meta = {}
        file_list = []
        for line in l_meta:
            _, key, _, _, label = line.strip().split(' ')
            file_list.append(key)
            d_meta[key] = 1 if label == 'bonafide' else 0
        return d_meta, file_list

    # Eval protocol: each line is just the utterance key.
    return [line.strip() for line in l_meta]
42
+
43
class Dataset_ASVspoof2019_train(Dataset):
    """Training dataset for ASVspoof2019: loads .flac utterances at 16 kHz,
    pads/trims each to a fixed length, and pairs it with its label."""

    def __init__(self, list_IDs, labels, base_dir, cut=64600):
        """
        Args:
            list_IDs: Utterance keys (strings).
            labels: Mapping from key to its label.
            base_dir: Directory containing the audio files (no extra 'flac'
                subdirectory is appended).
            cut: Maximum waveform length in samples (default: 64600).
        """
        self.list_IDs = list_IDs
        self.labels = labels
        self.base_dir = base_dir
        self.cut = cut

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_key = self.list_IDs[index]
        # base_dir already points at the audio directory — do not append 'flac' again.
        audio_file = os.path.join(self.base_dir, f"{utt_key}.flac")
        if not os.path.exists(audio_file):
            raise FileNotFoundError(f"File not found: {audio_file}")

        samples, _sr = librosa.load(audio_file, sr=16000)
        waveform = Tensor(pad(samples, self.cut))
        return waveform, self.labels[utt_key]
71
+
72
+
73
# Dataset class for the ASVspoof2021 evaluation data
class Dataset_ASVspoof2021_eval(Dataset):
    """Evaluation dataset: yields (fixed-length waveform tensor, utterance key)."""

    def __init__(self, list_IDs, base_dir, cut=64600):
        # Replace spaces so keys map cleanly onto file names.
        self.list_IDs = [utt.replace(' ', '_') for utt in list_IDs]
        self.base_dir = base_dir
        self.cut = cut

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_key = self.list_IDs[index]
        audio_file = os.path.join(self.base_dir, f"{utt_key}.flac")
        if not os.path.exists(audio_file):
            # Diagnostic aid: report the exact path that was probed.
            print(f"Checking file: {audio_file}, Exists: {os.path.exists(audio_file)}")
            raise FileNotFoundError(f"File not found: {audio_file}")

        samples, _sr = librosa.load(audio_file, sr=16000)
        return Tensor(pad(samples, self.cut)), utt_key
94
+
inference_onnx.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import librosa
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ import torch
6
+ from data_utils import pad # Import the pad function from data_utils.py
7
+
8
# Preprocess audio for a single segment
def preprocess_audio_segment(segment, cut=64600):
    """Force one audio segment to exactly ``cut`` samples and return it as a
    float32 numpy batch of shape (1, cut) for ONNX Runtime."""
    if len(segment) >= cut:
        fixed = segment[:cut]               # truncate long segments
    else:
        fixed = pad(segment, max_len=cut)   # tile-pad short ones
    arr = np.array(fixed, dtype=np.float32)
    return np.expand_dims(arr, axis=0)      # prepend the batch dimension
18
+
19
# Perform sliding window prediction
def predict_with_sliding_window(audio_path, onnx_model_path, window_size=64600, step_size=64600, sample_rate=16000):
    """
    Use a sliding window to predict whether an audio file is real or fake.

    Args:
        audio_path: Path to the input audio (any format librosa can decode).
        onnx_model_path: Path to the exported RawNet ONNX model.
        window_size: Samples per analysis window (default 64600 ≈ 4 s @ 16 kHz).
        step_size: Hop between window starts; equal to window_size means no overlap.
        sample_rate: Target sample rate for decoding (default 16000).

    Returns:
        (majority_class, avg_probability): "Real"/"Fake" by majority vote over
        windows, and the mean confidence (percent) of each window's winner.

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    # NOTE(review): the session is rebuilt on every call; hoist it to module
    # scope (or cache it) if this is invoked repeatedly.
    ort_session = ort.InferenceSession(onnx_model_path)

    waveform, _ = librosa.load(audio_path, sr=sample_rate)
    if len(waveform) == 0:
        # Fix: an empty file previously crashed with a cryptic ValueError
        # from max() at the aggregation step.
        raise ValueError(f"No audio samples decoded from: {audio_path}")

    total_segments = []
    total_probabilities = []

    input_name = ort_session.get_inputs()[0].name  # hoisted out of the loop

    for start in range(0, len(waveform), step_size):
        segment = waveform[start:start + window_size]
        audio_tensor = preprocess_audio_segment(segment)

        outputs = ort_session.run(None, {input_name: audio_tensor})
        # Convert logits to a torch tensor for softmax/argmax post-processing.
        probabilities = torch.nn.functional.softmax(torch.tensor(outputs[0]), dim=1)
        prediction = torch.argmax(probabilities, dim=1)

        # Class convention: 1 == genuine speech, 0 == spoofed.
        predicted_class = "Real" if prediction.item() == 1 else "Fake"
        probability = probabilities[0, prediction.item()].item() * 100
        total_segments.append(predicted_class)
        total_probabilities.append(probability)

        print(f"Segment {start//step_size + 1}: {predicted_class}, Probability: {probability:.2f}%")

    # Majority vote over windows; average the winning-class confidences.
    majority_class = max(set(total_segments), key=total_segments.count)
    avg_probability = np.mean(total_probabilities)

    return majority_class, avg_probability
60
+
61
# Main script for inference
if __name__ == "__main__":
    # Path to the ONNX model
    # NOTE(review): developer-machine-specific absolute paths; parameterize
    # (argparse/env var) before sharing or deploying.
    onnx_model_path = 'C:\\Users\\GOOD\\Desktop\\TEST-2024\\2021\\LA\\Baseline-RawNet2\\checkpoints\\RawNet_model.onnx'

    # Specify the path to the audio file
    audio_path = "C:\\Users\\GOOD\\Desktop\\TEST-2024\\2021\\LA\\Baseline-RawNet2\\audio\\KTA.mp3"  # Example .mp3 file

    # Perform sliding window prediction over the whole file and report the
    # majority-voted label with its average confidence.
    result, avg_probability = predict_with_sliding_window(audio_path, onnx_model_path)

    print(f"Final Result: {result}, Average Probability: {avg_probability:.2f}%")