Update pipeline.py
Browse files- pipeline.py +90 -78
pipeline.py
CHANGED
|
@@ -11,8 +11,7 @@ from facenet_pytorch import MTCNN
|
|
| 11 |
from rawnet import RawNet
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
-
#Set random seed for reproducibility.
|
| 16 |
tf.random.set_seed(42)
|
| 17 |
|
| 18 |
# Extract model if not already extracted
|
|
@@ -24,9 +23,7 @@ if not os.path.exists("efficientnet-b0"):
|
|
| 24 |
zip_ref.close()
|
| 25 |
print("Model extracted successfully!")
|
| 26 |
|
| 27 |
-
# Load
|
| 28 |
-
# Load model without compiling to avoid optimizer dependency issues
|
| 29 |
-
# Load model using TFSMLayer (Keras 3 compatible)
|
| 30 |
model = tf.keras.layers.TFSMLayer(
|
| 31 |
"efficientnet-b0/",
|
| 32 |
call_endpoint="serving_default"
|
|
@@ -37,18 +34,15 @@ def convert_to_mp4(input_path):
|
|
| 37 |
"""
|
| 38 |
Convert any video (e.g. .webm from webcam) to .mp4 using ffmpeg.
|
| 39 |
Returns the path to the converted file, or the original path if already mp4.
|
| 40 |
-
The caller is responsible for deleting the temp file when done.
|
| 41 |
"""
|
| 42 |
ext = os.path.splitext(input_path)[-1].lower()
|
| 43 |
if ext == ".mp4":
|
| 44 |
-
# Already mp4 β verify OpenCV can actually open it
|
| 45 |
cap = cv2.VideoCapture(input_path)
|
| 46 |
ok = cap.isOpened()
|
| 47 |
cap.release()
|
| 48 |
if ok:
|
| 49 |
-
return input_path, False
|
| 50 |
|
| 51 |
-
# Write to a named temp file so OpenCV can open it by path
|
| 52 |
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
| 53 |
tmp.close()
|
| 54 |
output_path = tmp.name
|
|
@@ -65,16 +59,14 @@ def convert_to_mp4(input_path):
|
|
| 65 |
result = subprocess.run(cmd, capture_output=True)
|
| 66 |
if result.returncode != 0:
|
| 67 |
os.unlink(output_path)
|
| 68 |
-
raise RuntimeError(
|
| 69 |
-
|
| 70 |
-
)
|
| 71 |
-
return output_path, True # (path, is_temp)
|
| 72 |
|
| 73 |
|
| 74 |
class DetectionPipeline:
|
| 75 |
"""Pipeline class for detecting faces in the frames of a video file."""
|
| 76 |
|
| 77 |
-
def __init__(self, n_frames=None, batch_size=60, resize=None, input_modality
|
| 78 |
self.n_frames = n_frames
|
| 79 |
self.batch_size = batch_size
|
| 80 |
self.resize = resize
|
|
@@ -83,9 +75,6 @@ class DetectionPipeline:
|
|
| 83 |
def __call__(self, filename):
|
| 84 |
if self.input_modality == 'video':
|
| 85 |
print('Input modality is video.')
|
| 86 |
-
|
| 87 |
-
# BUG FIX: Webcam recordings from Gradio arrive as .webm (VP8/VP9).
|
| 88 |
-
# OpenCV has no WebM support in headless builds β convert to .mp4 first.
|
| 89 |
converted_path, is_temp = convert_to_mp4(filename)
|
| 90 |
print(f"Processing video: {converted_path} (converted={is_temp})")
|
| 91 |
|
|
@@ -112,98 +101,88 @@ class DetectionPipeline:
|
|
| 112 |
if not success:
|
| 113 |
continue
|
| 114 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 115 |
-
|
| 116 |
if self.resize is not None:
|
| 117 |
frame = frame.resize([int(d * self.resize) for d in frame.size])
|
| 118 |
frames.append(frame)
|
| 119 |
-
|
| 120 |
if len(frames) % self.batch_size == 0 or j == sample[-1]:
|
| 121 |
face2 = cv2.resize(frame, (224, 224))
|
| 122 |
faces.append(face2)
|
| 123 |
|
| 124 |
v_cap.release()
|
| 125 |
finally:
|
| 126 |
-
# Clean up the temp converted file
|
| 127 |
if is_temp and os.path.exists(converted_path):
|
| 128 |
os.unlink(converted_path)
|
| 129 |
|
| 130 |
if len(faces) == 0:
|
| 131 |
raise RuntimeError("No frames could be extracted from the video.")
|
| 132 |
-
|
| 133 |
return faces
|
| 134 |
|
| 135 |
elif self.input_modality == 'image':
|
| 136 |
print('Input modality is image.')
|
| 137 |
-
print('Reading image')
|
| 138 |
image = cv2.cvtColor(filename, cv2.COLOR_BGR2RGB)
|
| 139 |
image = cv2.resize(image, (224, 224))
|
| 140 |
return image
|
| 141 |
-
|
| 142 |
elif self.input_modality == 'audio':
|
| 143 |
print("Input modality is audio.")
|
| 144 |
x, sr = librosa.load(filename)
|
| 145 |
x_pt = torch.Tensor(x)
|
| 146 |
x_pt = torch.unsqueeze(x_pt, dim=0)
|
| 147 |
return x_pt
|
| 148 |
-
|
| 149 |
else:
|
| 150 |
raise ValueError("Invalid input modality. Must be either 'video' or 'image'")
|
| 151 |
|
|
|
|
| 152 |
detection_video_pipeline = DetectionPipeline(n_frames=5, batch_size=1, input_modality='video')
|
| 153 |
detection_image_pipeline = DetectionPipeline(batch_size=1, input_modality='image')
|
| 154 |
|
| 155 |
-
def deepfakes_video_predict(input_video):
|
| 156 |
|
|
|
|
| 157 |
faces = detection_video_pipeline(input_video)
|
| 158 |
total = 0
|
| 159 |
real_res = []
|
| 160 |
fake_res = []
|
| 161 |
|
| 162 |
for face in faces:
|
| 163 |
-
|
| 164 |
-
face2 = face/255
|
| 165 |
pred = model(np.expand_dims(face2, axis=0))
|
| 166 |
pred = list(pred.values())[0].numpy()[0]
|
| 167 |
-
|
| 168 |
real, fake = pred[0], pred[1]
|
| 169 |
real_res.append(real)
|
| 170 |
fake_res.append(fake)
|
| 171 |
-
|
| 172 |
-
total+=1
|
| 173 |
-
|
| 174 |
pred2 = pred[1]
|
| 175 |
-
|
| 176 |
if pred2 > 0.5:
|
| 177 |
-
|
| 178 |
else:
|
| 179 |
-
|
|
|
|
| 180 |
real_mean = np.mean(real_res)
|
| 181 |
fake_mean = np.mean(fake_res)
|
| 182 |
print(f"Real Faces: {real_mean}")
|
| 183 |
print(f"Fake Faces: {fake_mean}")
|
| 184 |
-
text = ""
|
| 185 |
|
| 186 |
if real_mean >= 0.5:
|
| 187 |
-
text = "The video is REAL. \n Deepfakes Confidence: " + str(round(100 - (real_mean*100), 3)) + "%"
|
| 188 |
else:
|
| 189 |
-
text = "The video is FAKE. \n Deepfakes Confidence: " + str(round(fake_mean*100, 3)) + "%"
|
| 190 |
-
|
| 191 |
return text
|
| 192 |
|
| 193 |
|
| 194 |
def deepfakes_image_predict(input_image):
|
| 195 |
faces = detection_image_pipeline(input_image)
|
| 196 |
-
face2 = faces/255
|
| 197 |
pred = model(np.expand_dims(face2, axis=0))
|
| 198 |
pred = list(pred.values())[0].numpy()[0]
|
| 199 |
-
|
| 200 |
real, fake = pred[0], pred[1]
|
| 201 |
if real > 0.5:
|
| 202 |
-
text2 = "The image is REAL. \n Deepfakes Confidence: " + str(round(100 - (real*100), 3)) + "%"
|
| 203 |
else:
|
| 204 |
-
text2 = "The image is FAKE. \n Deepfakes Confidence: " + str(round(fake*100, 3)) + "%"
|
| 205 |
return text2
|
| 206 |
-
|
|
|
|
| 207 |
def load_audio_model():
|
| 208 |
d_args = {
|
| 209 |
"nb_samp": 64600,
|
|
@@ -216,73 +195,106 @@ def load_audio_model():
|
|
| 216 |
"nb_gru_layer": 3,
|
| 217 |
"nb_classes": 2
|
| 218 |
}
|
| 219 |
-
|
| 220 |
audio_model = RawNet(d_args=d_args, device='cpu')
|
| 221 |
-
|
| 222 |
ckpt = torch.load('RawNet2.pth', map_location=torch.device('cpu'))
|
| 223 |
audio_model.load_state_dict(ckpt)
|
| 224 |
audio_model.eval()
|
| 225 |
return audio_model
|
| 226 |
|
| 227 |
-
audio_label_map = {
|
| 228 |
-
0: "Real audio",
|
| 229 |
-
1: "Fake audio"
|
| 230 |
-
}
|
| 231 |
|
| 232 |
-
RAWNET_SAMPLE_RATE = 16000 # RawNet2 was trained on 16kHz
|
| 233 |
NB_SAMP = 64600 # Exactly 4.0375 seconds at 16kHz
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
def deepfakes_audio_predict(input_audio):
|
| 236 |
"""
|
| 237 |
Gradio gr.Audio() returns a tuple: (sample_rate, numpy_array).
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
1.
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
| 245 |
"""
|
| 246 |
sr, x = input_audio
|
|
|
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
# Step 1: Convert to float32
|
| 251 |
x = x.astype(np.float32)
|
| 252 |
-
|
| 253 |
-
# Step 2: Normalize int16 β [-1.0, 1.0] range
|
| 254 |
if np.abs(x).max() > 1.0:
|
| 255 |
-
x = x / 32768.0
|
| 256 |
|
| 257 |
-
# Step
|
| 258 |
if x.ndim == 2:
|
| 259 |
x = x.mean(axis=1)
|
| 260 |
|
| 261 |
-
# Step
|
| 262 |
-
# RawNet2's SincConv filterbank is hard-coded to 16kHz frequencies.
|
| 263 |
-
# Feeding audio at any other sample rate produces completely wrong filter responses.
|
| 264 |
if sr != RAWNET_SAMPLE_RATE:
|
| 265 |
-
print(f"[Audio] Resampling
|
| 266 |
x = librosa.resample(x, orig_sr=sr, target_sr=RAWNET_SAMPLE_RATE)
|
| 267 |
print(f"[Audio] After resample: {len(x)} samples ({len(x)/RAWNET_SAMPLE_RATE:.2f}s)")
|
| 268 |
|
| 269 |
-
# Step
|
| 270 |
if len(x) < NB_SAMP:
|
| 271 |
x = np.pad(x, (0, NB_SAMP - len(x)), mode='constant')
|
| 272 |
else:
|
| 273 |
x = x[:NB_SAMP]
|
| 274 |
|
| 275 |
-
# Step
|
| 276 |
-
x_pt = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
|
| 277 |
-
|
| 278 |
audio_model = load_audio_model()
|
| 279 |
|
| 280 |
with torch.no_grad():
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
logits_np = logits.detach().numpy()
|
| 284 |
-
result = np.argmax(logits_np)
|
| 285 |
|
| 286 |
-
|
|
|
|
|
|
|
| 287 |
|
| 288 |
-
|
|
|
|
|
|
| 11 |
from rawnet import RawNet
|
| 12 |
|
| 13 |
|
| 14 |
+
# Set random seed for reproducibility.
|
|
|
|
| 15 |
tf.random.set_seed(42)
|
| 16 |
|
| 17 |
# Extract model if not already extracted
|
|
|
|
| 23 |
zip_ref.close()
|
| 24 |
print("Model extracted successfully!")
|
| 25 |
|
| 26 |
+
# Load EfficientNet model using TFSMLayer (Keras 3 compatible)
|
|
|
|
|
|
|
| 27 |
model = tf.keras.layers.TFSMLayer(
|
| 28 |
"efficientnet-b0/",
|
| 29 |
call_endpoint="serving_default"
|
|
|
|
| 34 |
"""
|
| 35 |
Convert any video (e.g. .webm from webcam) to .mp4 using ffmpeg.
|
| 36 |
Returns the path to the converted file, or the original path if already mp4.
|
|
|
|
| 37 |
"""
|
| 38 |
ext = os.path.splitext(input_path)[-1].lower()
|
| 39 |
if ext == ".mp4":
|
|
|
|
| 40 |
cap = cv2.VideoCapture(input_path)
|
| 41 |
ok = cap.isOpened()
|
| 42 |
cap.release()
|
| 43 |
if ok:
|
| 44 |
+
return input_path, False
|
| 45 |
|
|
|
|
| 46 |
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
| 47 |
tmp.close()
|
| 48 |
output_path = tmp.name
|
|
|
|
| 59 |
result = subprocess.run(cmd, capture_output=True)
|
| 60 |
if result.returncode != 0:
|
| 61 |
os.unlink(output_path)
|
| 62 |
+
raise RuntimeError(f"ffmpeg conversion failed:\n{result.stderr.decode()}")
|
| 63 |
+
return output_path, True
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
class DetectionPipeline:
|
| 67 |
"""Pipeline class for detecting faces in the frames of a video file."""
|
| 68 |
|
| 69 |
+
def __init__(self, n_frames=None, batch_size=60, resize=None, input_modality='video'):
|
| 70 |
self.n_frames = n_frames
|
| 71 |
self.batch_size = batch_size
|
| 72 |
self.resize = resize
|
|
|
|
| 75 |
def __call__(self, filename):
|
| 76 |
if self.input_modality == 'video':
|
| 77 |
print('Input modality is video.')
|
|
|
|
|
|
|
|
|
|
| 78 |
converted_path, is_temp = convert_to_mp4(filename)
|
| 79 |
print(f"Processing video: {converted_path} (converted={is_temp})")
|
| 80 |
|
|
|
|
| 101 |
if not success:
|
| 102 |
continue
|
| 103 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
|
| 104 |
if self.resize is not None:
|
| 105 |
frame = frame.resize([int(d * self.resize) for d in frame.size])
|
| 106 |
frames.append(frame)
|
|
|
|
| 107 |
if len(frames) % self.batch_size == 0 or j == sample[-1]:
|
| 108 |
face2 = cv2.resize(frame, (224, 224))
|
| 109 |
faces.append(face2)
|
| 110 |
|
| 111 |
v_cap.release()
|
| 112 |
finally:
|
|
|
|
| 113 |
if is_temp and os.path.exists(converted_path):
|
| 114 |
os.unlink(converted_path)
|
| 115 |
|
| 116 |
if len(faces) == 0:
|
| 117 |
raise RuntimeError("No frames could be extracted from the video.")
|
|
|
|
| 118 |
return faces
|
| 119 |
|
| 120 |
elif self.input_modality == 'image':
|
| 121 |
print('Input modality is image.')
|
|
|
|
| 122 |
image = cv2.cvtColor(filename, cv2.COLOR_BGR2RGB)
|
| 123 |
image = cv2.resize(image, (224, 224))
|
| 124 |
return image
|
| 125 |
+
|
| 126 |
elif self.input_modality == 'audio':
|
| 127 |
print("Input modality is audio.")
|
| 128 |
x, sr = librosa.load(filename)
|
| 129 |
x_pt = torch.Tensor(x)
|
| 130 |
x_pt = torch.unsqueeze(x_pt, dim=0)
|
| 131 |
return x_pt
|
| 132 |
+
|
| 133 |
else:
|
| 134 |
raise ValueError("Invalid input modality. Must be either 'video' or 'image'")
|
| 135 |
|
| 136 |
+
|
| 137 |
detection_video_pipeline = DetectionPipeline(n_frames=5, batch_size=1, input_modality='video')
|
| 138 |
detection_image_pipeline = DetectionPipeline(batch_size=1, input_modality='image')
|
| 139 |
|
|
|
|
| 140 |
|
| 141 |
+
def deepfakes_video_predict(input_video):
|
| 142 |
faces = detection_video_pipeline(input_video)
|
| 143 |
total = 0
|
| 144 |
real_res = []
|
| 145 |
fake_res = []
|
| 146 |
|
| 147 |
for face in faces:
|
| 148 |
+
face2 = face / 255
|
|
|
|
| 149 |
pred = model(np.expand_dims(face2, axis=0))
|
| 150 |
pred = list(pred.values())[0].numpy()[0]
|
|
|
|
| 151 |
real, fake = pred[0], pred[1]
|
| 152 |
real_res.append(real)
|
| 153 |
fake_res.append(fake)
|
| 154 |
+
total += 1
|
|
|
|
|
|
|
| 155 |
pred2 = pred[1]
|
|
|
|
| 156 |
if pred2 > 0.5:
|
| 157 |
+
fake += 1
|
| 158 |
else:
|
| 159 |
+
real += 1
|
| 160 |
+
|
| 161 |
real_mean = np.mean(real_res)
|
| 162 |
fake_mean = np.mean(fake_res)
|
| 163 |
print(f"Real Faces: {real_mean}")
|
| 164 |
print(f"Fake Faces: {fake_mean}")
|
|
|
|
| 165 |
|
| 166 |
if real_mean >= 0.5:
|
| 167 |
+
text = "The video is REAL. \n Deepfakes Confidence: " + str(round(100 - (real_mean * 100), 3)) + "%"
|
| 168 |
else:
|
| 169 |
+
text = "The video is FAKE. \n Deepfakes Confidence: " + str(round(fake_mean * 100, 3)) + "%"
|
|
|
|
| 170 |
return text
|
| 171 |
|
| 172 |
|
| 173 |
def deepfakes_image_predict(input_image):
|
| 174 |
faces = detection_image_pipeline(input_image)
|
| 175 |
+
face2 = faces / 255
|
| 176 |
pred = model(np.expand_dims(face2, axis=0))
|
| 177 |
pred = list(pred.values())[0].numpy()[0]
|
|
|
|
| 178 |
real, fake = pred[0], pred[1]
|
| 179 |
if real > 0.5:
|
| 180 |
+
text2 = "The image is REAL. \n Deepfakes Confidence: " + str(round(100 - (real * 100), 3)) + "%"
|
| 181 |
else:
|
| 182 |
+
text2 = "The image is FAKE. \n Deepfakes Confidence: " + str(round(fake * 100, 3)) + "%"
|
| 183 |
return text2
|
| 184 |
+
|
| 185 |
+
|
| 186 |
def load_audio_model():
|
| 187 |
d_args = {
|
| 188 |
"nb_samp": 64600,
|
|
|
|
| 195 |
"nb_gru_layer": 3,
|
| 196 |
"nb_classes": 2
|
| 197 |
}
|
|
|
|
| 198 |
audio_model = RawNet(d_args=d_args, device='cpu')
|
|
|
|
| 199 |
ckpt = torch.load('RawNet2.pth', map_location=torch.device('cpu'))
|
| 200 |
audio_model.load_state_dict(ckpt)
|
| 201 |
audio_model.eval()
|
| 202 |
return audio_model
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
RAWNET_SAMPLE_RATE = 16000 # RawNet2 was trained strictly on 16kHz β never change
|
| 206 |
NB_SAMP = 64600 # Exactly 4.0375 seconds at 16kHz
|
| 207 |
|
| 208 |
+
# βββ Confidence thresholds for 3-class labelling ββββββββββββββββββββββββββββ
|
| 209 |
+
# RawNet2 has 2 output classes (real / fake). We derive a 3rd class
|
| 210 |
+
# "AI Synthesized" from the confidence score:
|
| 211 |
+
#
|
| 212 |
+
# real_prob >= REAL_THRESHOLD β Genuine human voice
|
| 213 |
+
# fake_prob >= FAKE_THRESHOLD β Manipulated / spliced audio
|
| 214 |
+
# anything in between β AI Synthesized / TTS / Voice-cloned
|
| 215 |
+
#
|
| 216 |
+
# Why this works: TTS and voice-clone audio confuses RawNet2 β it produces
|
| 217 |
+
# low-confidence outputs for both classes because it was trained on older
|
| 218 |
+
# spoofing attacks. That uncertainty is the signal we exploit.
|
| 219 |
+
REAL_THRESHOLD = 0.75
|
| 220 |
+
FAKE_THRESHOLD = 0.75
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def classify_audio_3class(real_prob: float, fake_prob: float) -> str:
|
| 224 |
+
"""
|
| 225 |
+
Map RawNet2 2-class probabilities β 3-class human-readable label.
|
| 226 |
+
|
| 227 |
+
Classes:
|
| 228 |
+
- Real Human Voice : model is confident it's real
|
| 229 |
+
- AI Synthesized : model is uncertain (TTS / voice-clone zone)
|
| 230 |
+
- Fake / Manipulated : model is confident it's fake (spliced, replayed)
|
| 231 |
+
"""
|
| 232 |
+
print(f"[Audio] real_prob={real_prob:.4f} fake_prob={fake_prob:.4f}")
|
| 233 |
+
|
| 234 |
+
if real_prob >= REAL_THRESHOLD:
|
| 235 |
+
confidence = round(real_prob * 100, 2)
|
| 236 |
+
return f"β
Real Human Voice\nConfidence: {confidence}%"
|
| 237 |
+
|
| 238 |
+
elif fake_prob >= FAKE_THRESHOLD:
|
| 239 |
+
confidence = round(fake_prob * 100, 2)
|
| 240 |
+
return f"π¨ Fake / Manipulated Audio\nConfidence: {confidence}%"
|
| 241 |
+
|
| 242 |
+
else:
|
| 243 |
+
# Low confidence on both sides β hallmark of modern TTS / voice cloning
|
| 244 |
+
ai_confidence = round(fake_prob * 100, 2)
|
| 245 |
+
return (
|
| 246 |
+
f"π€ AI Synthesized / Voice Cloned\n"
|
| 247 |
+
f"Confidence: {ai_confidence}%\n"
|
| 248 |
+
f"(Model uncertainty indicates TTS or neural voice cloning)"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
def deepfakes_audio_predict(input_audio):
|
| 253 |
"""
|
| 254 |
Gradio gr.Audio() returns a tuple: (sample_rate, numpy_array).
|
| 255 |
+
|
| 256 |
+
Pipeline:
|
| 257 |
+
1. float32 conversion + int16 normalisation
|
| 258 |
+
2. Stereo β mono
|
| 259 |
+
3. Resample to 16000 Hz β critical: RawNet2 SincConv assumes 16kHz
|
| 260 |
+
4. Pad / trim to NB_SAMP (64600) samples
|
| 261 |
+
5. RawNet2 inference β log-softmax β probabilities
|
| 262 |
+
6. 3-class decision via confidence thresholds
|
| 263 |
"""
|
| 264 |
sr, x = input_audio
|
| 265 |
+
print(f"[Audio] Input SR={sr} Hz | samples={len(x)} | dtype={x.dtype}")
|
| 266 |
|
| 267 |
+
# Step 1 β float32 + normalise
|
|
|
|
|
|
|
| 268 |
x = x.astype(np.float32)
|
|
|
|
|
|
|
| 269 |
if np.abs(x).max() > 1.0:
|
| 270 |
+
x = x / 32768.0 # int16 β [-1, 1]
|
| 271 |
|
| 272 |
+
# Step 2 β stereo β mono (must precede librosa.resample which needs 1-D)
|
| 273 |
if x.ndim == 2:
|
| 274 |
x = x.mean(axis=1)
|
| 275 |
|
| 276 |
+
# Step 3 β resample to 16 kHz (THE root-cause fix)
|
|
|
|
|
|
|
| 277 |
if sr != RAWNET_SAMPLE_RATE:
|
| 278 |
+
print(f"[Audio] Resampling {sr} Hz β {RAWNET_SAMPLE_RATE} Hz β¦")
|
| 279 |
x = librosa.resample(x, orig_sr=sr, target_sr=RAWNET_SAMPLE_RATE)
|
| 280 |
print(f"[Audio] After resample: {len(x)} samples ({len(x)/RAWNET_SAMPLE_RATE:.2f}s)")
|
| 281 |
|
| 282 |
+
# Step 4 β pad or trim to exactly NB_SAMP
|
| 283 |
if len(x) < NB_SAMP:
|
| 284 |
x = np.pad(x, (0, NB_SAMP - len(x)), mode='constant')
|
| 285 |
else:
|
| 286 |
x = x[:NB_SAMP]
|
| 287 |
|
| 288 |
+
# Step 5 β inference
|
| 289 |
+
x_pt = torch.tensor(x, dtype=torch.float32).unsqueeze(0) # [1, NB_SAMP]
|
|
|
|
| 290 |
audio_model = load_audio_model()
|
| 291 |
|
| 292 |
with torch.no_grad():
|
| 293 |
+
log_probs = audio_model(x_pt) # log-softmax output
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
+
probs = torch.exp(log_probs).numpy()[0] # convert log β actual probabilities
|
| 296 |
+
real_prob = float(probs[0])
|
| 297 |
+
fake_prob = float(probs[1])
|
| 298 |
|
| 299 |
+
# Step 6 β 3-class label
|
| 300 |
+
return classify_audio_3class(real_prob, fake_prob)
|