|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import librosa |
|
|
import numpy as np |
|
|
import tempfile |
|
|
import shutil |
|
|
import uuid |
|
|
import base64 as b64 |
|
|
from fastapi import FastAPI, HTTPException, File, UploadFile, Header, Depends |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field, validator |
|
|
from typing import Optional |
|
|
from transformers import WavLMModel |
|
|
|
|
|
|
|
|
# Inference device: the default CUDA device when one is visible, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)
|
|
|
|
|
|
|
|
# Shared secret that clients must send in the ``x-api-key`` header.
# Read once at import time; None when the environment variable is absent.
API_KEY = os.environ.get("API_KEY")

if not API_KEY:
    print("⚠️ WARNING: API_KEY not set! Set API_KEY environment variable in HF Spaces.")
else:
    print("✓ API key loaded from environment variable")
|
|
|
|
|
# Languages accepted by the request validators (matched case-insensitively).
SUPPORTED_LANGUAGES: list[str] = ["Tamil", "English", "Hindi", "Malayalam", "Telugu"]

# --- Audio framing -----------------------------------------------------------
SAMPLE_RATE: int = 16_000          # Hz; all audio is resampled to this rate
TARGET_DURATION: float = 5.0       # seconds per model window
MAX_AUDIO_DURATION: float = 60.0   # seconds; longer inputs are rejected
SLIDING_WINDOW_HOP: float = 2.5    # seconds between window starts for long audio

# --- Loudness normalisation --------------------------------------------------
NORM_TYPE: str = "peak"            # "peak" or "rms"
RMS_TARGET: float = 0.1            # target RMS when NORM_TYPE == "rms"
SILENCE_THRESHOLD: float = 0.0001  # peak below this: audio is left unnormalised

# --- Input quality gates -----------------------------------------------------
MIN_RMS_ENERGY: float = 5e-3
MAX_SILENCE_RATIO: float = 0.9
MIN_SPEECH_PROB: float = 0.3       # NOTE(review): not referenced by the visible code
MAX_ZERO_CROSSING_RATE: float = 0.7
MIN_ZERO_CROSSING_RATE: float = 0.02
MAX_SPECTRAL_CENTROID: int = 5_000   # Hz
MIN_SPECTRAL_CENTROID: int = 150     # Hz
MAX_CLIPPING_RATIO: float = 0.02

# --- Optional spectral-gate denoiser (disabled by default) -------------------
USE_DENOISE: bool = False
DENOISE_N_FFT: int = 1024
DENOISE_HOP_LENGTH: int = 256
DENOISE_NOISE_PERCENTILE: int = 10
DENOISE_THRESHOLD_MULT: float = 1.5
DENOISE_ATTENUATION: float = 0.2

# --- Speech bandpass ---------------------------------------------------------
USE_BANDPASS: bool = True
HIGHPASS_CUTOFF_HZ: float = 80.0
LOWPASS_CUTOFF_HZ: float = 7800.0

# --- Model heads -------------------------------------------------------------
DROPOUT_P: float = 0.3

# Blend weights for the two head scores (they sum to 1).
AASIST_WEIGHT: float = 0.6
OCSOFT_WEIGHT: float = 0.4

# Default decision threshold; may be replaced from optimal_threshold.txt at startup.
OPTIMAL_THRESHOLD: float = 0.5
|
|
|
|
|
|
|
|
def _apply_bandpass_torch(wav_t: torch.Tensor, sr: int) -> torch.Tensor:
    """Band-limit ``wav_t`` to the speech range when ``USE_BANDPASS`` is on.

    A high-pass biquad at ``HIGHPASS_CUTOFF_HZ`` (rumble) runs first, followed by a
    low-pass biquad at ``LOWPASS_CUTOFF_HZ`` (hiss). With bandpassing disabled the
    very same tensor object is returned untouched.
    """
    if USE_BANDPASS:
        high_passed = torchaudio.functional.highpass_biquad(wav_t, sr, cutoff_freq=HIGHPASS_CUTOFF_HZ)
        return torchaudio.functional.lowpass_biquad(high_passed, sr, cutoff_freq=LOWPASS_CUTOFF_HZ)
    return wav_t
|
|
|
|
|
|
|
|
def _frame_rms(wav_np: np.ndarray, frame_length: int, hop_length: int) -> np.ndarray:
    """Return the RMS energy of each analysis frame of ``wav_np``.

    ``librosa.util.frame`` rejects signals shorter than one frame, so such a short
    signal is treated as a single frame covering all of its samples.
    """
    if len(wav_np) < frame_length:
        return np.array([np.sqrt(np.mean(wav_np ** 2))])
    frames = librosa.util.frame(wav_np, frame_length=frame_length, hop_length=hop_length)
    return np.sqrt(np.mean(frames ** 2, axis=0))


def _validate_audio_quality(wav_np: np.ndarray, sr: int) -> dict:
    """
    Validate audio quality and content. Returns dict with validation results.

    Checks, in order:
      1. the signal is non-empty;
      2. overall RMS is at least ``MIN_RMS_ENERGY``;
      3. at most ``MAX_SILENCE_RATIO`` of 20 ms frames (10 ms hop) are quieter than
         half of ``MIN_RMS_ENERGY``;
      4. at most ``MAX_CLIPPING_RATIO`` of samples exceed 0.99 in magnitude;
      5. a vote over spectral heuristics (zero-crossing rate, spectral centroid,
         85% roll-off). Each out-of-range statistic is one vote, and two or more
         votes reject the clip as non-speech.

    Args:
        wav_np: mono waveform; presumably float in [-1, 1] (the clipping check
            treats 0.99 as full scale).
        sr: sample rate of ``wav_np`` in Hz.

    Returns:
        dict of floats: ``rms``, ``silence_ratio``, ``zero_crossing_rate``,
        ``spectral_centroid``, ``spectral_rolloff``, ``clipping_ratio``.

    Raises:
        ValueError: when the audio should be rejected; the message is user-facing.
    """
    if len(wav_np) == 0:
        raise ValueError("Audio is empty")

    # Overall loudness gate.
    rms = np.sqrt(np.mean(wav_np ** 2))
    if rms < MIN_RMS_ENERGY:
        raise ValueError(f"Audio is too quiet (RMS: {rms:.6f}). Please provide clear audio.")

    # Share of near-silent 20 ms frames. Clips shorter than one frame are handled
    # by _frame_rms instead of crashing inside librosa.
    frame_length = max(1, int(0.02 * sr))
    hop_length = max(1, frame_length // 2)
    frame_rms = _frame_rms(wav_np, frame_length, hop_length)
    silence_ratio = np.sum(frame_rms < MIN_RMS_ENERGY * 0.5) / len(frame_rms)
    if silence_ratio > MAX_SILENCE_RATIO:
        raise ValueError(f"Audio contains {silence_ratio*100:.1f}% silence. Please provide clear speech.")

    # Share of samples at (or beyond) full scale.
    clipping_ratio = np.sum(np.abs(wav_np) > 0.99) / len(wav_np)
    if clipping_ratio > MAX_CLIPPING_RATIO:
        raise ValueError(f"Audio is clipped/distorted ({clipping_ratio*100:.1f}% samples). Please provide undistorted audio.")

    # Spectral heuristics for "does this look like speech?".
    zcr = np.mean(librosa.zero_crossings(wav_np))
    spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=wav_np, sr=sr))
    rolloff = np.mean(librosa.feature.spectral_rolloff(y=wav_np, sr=sr, roll_percent=0.85))

    # NOTE: the roll-off cannot exceed the Nyquist frequency (sr / 2), so at
    # SAMPLE_RATE = 16 kHz the 10 kHz upper bound never fires; it only matters
    # if this is called with a higher sample rate.
    non_speech_indicators = sum(
        bool(flag)
        for flag in (
            zcr > MAX_ZERO_CROSSING_RATE,
            zcr < MIN_ZERO_CROSSING_RATE,
            spectral_centroid > MAX_SPECTRAL_CENTROID,
            spectral_centroid < MIN_SPECTRAL_CENTROID,
            rolloff > 10000,
            rolloff < 800,
        )
    )
    if non_speech_indicators >= 2:
        raise ValueError(f"Audio does not appear to be clear speech (ZCR: {zcr:.3f}, Centroid: {spectral_centroid:.0f}Hz, Rolloff: {rolloff:.0f}Hz). Please provide speech-only audio.")

    return {
        "rms": float(rms),
        "silence_ratio": float(silence_ratio),
        "zero_crossing_rate": float(zcr),
        "spectral_centroid": float(spectral_centroid),
        "spectral_rolloff": float(rolloff),
        "clipping_ratio": float(clipping_ratio)
    }
|
|
|
|
|
|
|
|
def _denoise_spectral_gate_np(wav_np: np.ndarray, sr: int) -> np.ndarray:
    """Attenuate STFT bins that sit near their frequency's estimated noise floor.

    The noise floor of each frequency bin is the ``DENOISE_NOISE_PERCENTILE``-th
    percentile of its magnitude over time. Bins at or above
    ``floor * DENOISE_THRESHOLD_MULT`` pass through unchanged; the rest are scaled
    by ``DENOISE_ATTENUATION``. The original phase is reused for resynthesis.

    The input is returned unchanged when ``USE_DENOISE`` is off, the array is empty,
    or it contains non-finite values. ``sr`` is accepted for interface symmetry
    but not used.
    """
    if not USE_DENOISE or wav_np.size == 0 or not np.isfinite(wav_np).all():
        return wav_np

    spectrum = librosa.stft(wav_np, n_fft=DENOISE_N_FFT, hop_length=DENOISE_HOP_LENGTH)
    magnitude = np.abs(spectrum)
    unit_phase = np.exp(1j * np.angle(spectrum))

    floor = np.percentile(magnitude, DENOISE_NOISE_PERCENTILE, axis=1, keepdims=True)
    keep = (magnitude >= floor * float(DENOISE_THRESHOLD_MULT)).astype(np.float32)
    gated = magnitude * keep + magnitude * (1.0 - keep) * float(DENOISE_ATTENUATION)

    restored = librosa.istft(gated * unit_phase, hop_length=DENOISE_HOP_LENGTH, length=len(wav_np))
    return restored.astype(np.float32)
|
|
|
|
|
|
|
|
def _sniff_audio_ext(audio_bytes: bytes) -> str: |
|
|
"""Best-effort format sniffing for base64/bytes inputs.""" |
|
|
if not audio_bytes: |
|
|
return ".wav" |
|
|
head = audio_bytes[:64] |
|
|
if head.startswith(b"RIFF") and b"WAVE" in head: |
|
|
return ".wav" |
|
|
if head.startswith(b"ID3") or (len(head) >= 2 and head[0] == 0xFF and (head[1] & 0xE0) == 0xE0): |
|
|
return ".mp3" |
|
|
return ".mp3" |
|
|
|
|
|
|
|
|
def _load_audio_any(audio_input, *, is_base64: bool, base64_format: str | None = None):
    """Load mono audio, resampled to ``SAMPLE_RATE``, from a path or a base64 string.

    Args:
        audio_input: a filesystem path (``is_base64=False``) or a base64 string,
            optionally prefixed by a ``data:<mime>;base64,`` URI header.
        is_base64: selects how ``audio_input`` is interpreted.
        base64_format: optional container hint such as ``"mp3"``, used as the
            temporary file's extension. Hints that are not short ASCII alphanumerics
            are ignored, and the format is sniffed from the bytes instead.

    Returns:
        ``(wav_np, sr)`` as produced by ``librosa.load`` with ``sr=SAMPLE_RATE``.

    Raises:
        ValueError: invalid base64, an empty payload, or base64 audio that cannot be
            decoded. MP3 failures get a dedicated message when ffmpeg is absent.
        Exception: for path inputs, whatever ``librosa.load`` raises. The exception
            is an MP3 with no ffmpeg available, which becomes ``ValueError``.
    """
    if not is_base64:
        path = str(audio_input)
        try:
            wav, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)
            return wav, sr
        except Exception as e:
            if path.lower().endswith(".mp3") and shutil.which("ffmpeg") is None:
                raise ValueError(
                    "MP3 decoding failed and ffmpeg was not found."
                ) from e
            raise

    payload = audio_input
    # Accept data URIs by dropping the "data:...;base64," header. Otherwise its
    # alphabet characters would be decoded into garbage leading bytes.
    if isinstance(payload, str) and payload.lstrip().startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    try:
        audio_bytes = b64.b64decode(payload)
    except Exception as e:
        raise ValueError("Invalid base64 audio") from e
    if not audio_bytes:
        raise ValueError("Invalid base64 audio: decoded payload is empty")

    ext = None
    if base64_format is not None:
        fmt = base64_format.strip().lower().lstrip(".")
        # The hint becomes part of a filename, so accept only short ASCII alphanumerics.
        if fmt and len(fmt) <= 10 and fmt.isascii() and fmt.isalnum():
            ext = "." + fmt
    if ext is None:
        ext = _sniff_audio_ext(audio_bytes)

    # mkstemp creates the file exclusively with a unique name in the temp dir.
    fd, tmp_path = tempfile.mkstemp(prefix="tmp_audio_", suffix=ext)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(audio_bytes)
        wav, sr = librosa.load(tmp_path, sr=SAMPLE_RATE, mono=True)
        return wav, sr
    except Exception as e:
        if ext == ".mp3" and shutil.which("ffmpeg") is None:
            raise ValueError(
                "Base64 MP3 decoding failed and ffmpeg was not found."
            ) from e
        raise ValueError(f"Error decoding base64 audio ({ext}): {str(e)}") from e
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best effort: a leftover temp file must not mask the real outcome.
            pass
|
|
|
|
|
|
|
|
class AASISTHead(nn.Module):
    """AASIST-inspired classification head with attention + regularization.

    Forward pass: one multi-head self-attention layer with a residual connection
    and LayerNorm, then a mean over the sequence axis, then an MLP
    (dim -> 256 -> 64 -> 1) with GELU and dropout, producing one logit per item.

    NOTE: attribute names and the order of layers inside ``self.mlp`` determine the
    state_dict keys (``attn.*``, ``norm.*``, ``mlp.0.*``, ``mlp.3.*``, ``mlp.6.*``)
    that trained checkpoints are loaded into. Renaming or reordering them breaks
    checkpoint compatibility.
    """

    def __init__(self, dim: int = 768, dropout: float = DROPOUT_P, n_heads: int = 8):
        # dim: feature width of each sequence element (768 presumably matches the
        #      backbone's hidden size - confirm if the backbone changes).
        # dropout: applied inside attention and between MLP layers.
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is (batch, seq, dim) because the attention layer uses batch_first=True.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        # Residual connection followed by LayerNorm.
        x = self.norm(x + attn_out)
        # Mean-pool over the sequence axis -> (batch, dim).
        pooled = x.mean(dim=1)
        # Raw logit of shape (batch, 1); callers apply the sigmoid.
        return self.mlp(pooled)
|
|
|
|
|
|
|
|
class OCSoftmaxHead(nn.Module):
    """Regularized one-class style head (trained with BCE).

    Forward pass: mean over the sequence axis, then LayerNorm, then an MLP
    (dim -> 256 -> 1) with GELU and dropout, producing one logit per item.

    NOTE: attribute names and the ``self.mlp`` layer order determine the state_dict
    keys (``norm.*``, ``mlp.0.*``, ``mlp.3.*``) that trained checkpoints are loaded
    into. Renaming or reordering them breaks checkpoint compatibility.
    """

    def __init__(self, dim: int = 768, dropout: float = DROPOUT_P):
        # dim: feature width of each sequence element (presumably the backbone's
        #      hidden size - confirm if the backbone changes).
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is (batch, seq, dim), matching AASISTHead - TODO confirm.
        # Pool over axis 1 first, then normalise the pooled vector.
        pooled = self.norm(x.mean(dim=1))
        # Raw logit of shape (batch, 1); callers apply the sigmoid.
        return self.mlp(pooled)
|
|
|
|
|
|
|
|
|
|
|
print("Loading models...")

# Self-supervised backbone: moved to DEVICE, put in inference mode, and frozen.
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
wavlm.to(DEVICE).eval()
wavlm.requires_grad_(False)

# The two classification heads whose scores are blended at inference time.
# Keep this construction order: it fixes how the random initial weights are drawn.
aasist = AASISTHead().to(DEVICE)
ocsoft = OCSoftmaxHead().to(DEVICE)
|
|
|
|
|
|
|
|
def load_state_dict_flexible(model, state_dict, strict: bool = True):
    """Load ``state_dict`` into ``model``, stripping a DataParallel ``module.`` prefix.

    Checkpoints saved from a DataParallel-style wrapper prefix every key with
    ``module.``. Only that *leading* prefix is removed. The substring elsewhere in a
    key (e.g. ``encoder.submodule.weight``) is left intact; the previous
    ``str.replace`` would have corrupted such a key.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict=...)``.
        state_dict: mapping of parameter names to tensors.
        strict: forwarded to ``model.load_state_dict``. The default True matches
            the previous behaviour.

    Returns:
        Whatever ``model.load_state_dict`` returns (for ``nn.Module``, the report of
        missing and unexpected keys).
    """
    if any(k.startswith("module.") for k in state_dict):
        state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    return model.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
|
|
|
|
|
|
# Trained checkpoint; a relative path, so it is resolved against the process's
# current working directory.
MODEL_PATH = "best_model.pt"

if os.path.exists(MODEL_PATH):
    print(f"Loading trained weights from {MODEL_PATH}")
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, and
    # unpickling can execute code. This is acceptable only if the file presumably
    # ships with the deployment; never point MODEL_PATH at user-supplied files.
    # Consider weights_only=True once the checkpoint is confirmed to hold only
    # tensors and plain containers.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # The checkpoint must provide 'wavlm', 'aasist' and 'ocsoft' state dicts;
    # a missing key raises KeyError at import time.
    load_state_dict_flexible(wavlm, checkpoint['wavlm'])
    load_state_dict_flexible(aasist, checkpoint['aasist'])
    load_state_dict_flexible(ocsoft, checkpoint['ocsoft'])
    print("Trained weights loaded successfully!")
else:
    # Startup continues anyway: pretrained backbone plus untrained heads.
    print("Warning: No trained weights found. Using randomly initialized heads.")
|
|
|
|
|
|
|
|
# Optional file holding a tuned decision threshold (one float in [0, 1]). When
# valid, it overrides the default OPTIMAL_THRESHOLD.
THRESHOLD_PATH = "optimal_threshold.txt"

if os.path.exists(THRESHOLD_PATH):
    try:
        with open(THRESHOLD_PATH, "r", encoding="utf-8") as f:
            _loaded_threshold = float(f.read().strip())
    except (OSError, ValueError) as e:
        # A corrupt or unreadable file should not take the whole service down.
        print(f"Warning: could not read {THRESHOLD_PATH} ({e}); using default threshold: {OPTIMAL_THRESHOLD:.4f}")
    else:
        # The chained comparison also rejects NaN.
        if 0.0 <= _loaded_threshold <= 1.0:
            OPTIMAL_THRESHOLD = _loaded_threshold
            print(f"Loaded optimal threshold: {OPTIMAL_THRESHOLD:.4f}")
        else:
            print(f"Warning: threshold {_loaded_threshold} in {THRESHOLD_PATH} is outside [0, 1]; using default threshold: {OPTIMAL_THRESHOLD:.4f}")
else:
    print(f"Using default threshold: {OPTIMAL_THRESHOLD:.4f}")

# Inference mode for both heads (disables their dropout layers).
aasist.eval()
ocsoft.eval()
|
|
|
|
|
|
|
|
def _extract_crop(wav: np.ndarray, target_length: int, crop_type: str = "center", seed: int = None) -> np.ndarray: |
|
|
""" |
|
|
Extract a crop from audio. |
|
|
crop_type: 'center', 'random', 'start', 'end' |
|
|
""" |
|
|
current_length = len(wav) |
|
|
|
|
|
if current_length <= target_length: |
|
|
|
|
|
pad_length = target_length - current_length |
|
|
if pad_length > current_length: |
|
|
|
|
|
repeats = (target_length // current_length) + 1 |
|
|
wav = np.tile(wav, repeats) |
|
|
current_length = len(wav) |
|
|
pad_length = target_length - current_length |
|
|
|
|
|
if pad_length > 0: |
|
|
pad_left = pad_length // 2 |
|
|
pad_right = pad_length - pad_left |
|
|
wav = np.pad(wav, (pad_left, pad_right), mode='reflect') |
|
|
return wav[:target_length] |
|
|
|
|
|
|
|
|
if crop_type == "center": |
|
|
start = (current_length - target_length) // 2 |
|
|
elif crop_type == "start": |
|
|
start = 0 |
|
|
elif crop_type == "end": |
|
|
start = current_length - target_length |
|
|
elif crop_type == "random": |
|
|
if seed is not None: |
|
|
np.random.seed(seed) |
|
|
start = np.random.randint(0, current_length - target_length + 1) |
|
|
else: |
|
|
start = (current_length - target_length) // 2 |
|
|
|
|
|
return wav[start:start + target_length] |
|
|
|
|
|
|
|
|
def _normalize_waveform(wav: np.ndarray) -> np.ndarray:
    """Apply ``NORM_TYPE`` loudness normalisation; near-silent audio is left as-is."""
    peak = np.abs(wav).max()
    if peak < SILENCE_THRESHOLD:
        # Scaling near-silence up would only amplify noise.
        return wav
    if NORM_TYPE == "peak":
        return wav / max(peak, 1e-6)
    if NORM_TYPE == "rms":
        rms = np.sqrt(np.mean(wav ** 2))
        if rms > 1e-6:
            wav = wav * (RMS_TARGET / rms)
        return np.clip(wav, -1.0, 1.0)
    return wav


def _sliding_window_starts(total_length: int, window_length: int, hop_length: int) -> list[int]:
    """Start offsets of windows every ``hop_length`` samples, plus one flush with the end.

    Requires ``total_length >= window_length``. The extra final start guarantees
    the tail of the signal is covered when the hop does not divide evenly.
    """
    last_start = total_length - window_length
    starts = list(range(0, last_start + 1, max(1, hop_length)))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def preprocess_audio(audio_input, is_base64=False, base64_format: str | None = None, return_multiple=False):
    """
    Preprocess audio for inference.

    Pipeline: load (mono, SAMPLE_RATE), then sanity checks (non-empty, finite, at
    most MAX_AUDIO_DURATION), then quality validation, then the optional
    spectral-gate denoise, then the optional bandpass, then NORM_TYPE
    normalisation, then windowing into TARGET_DURATION chunks.

    For short audio (<= TARGET_DURATION): padded/tiled to TARGET_DURATION.
    For long audio: the centre crop when ``return_multiple`` is False. Otherwise,
    windows every SLIDING_WINDOW_HOP seconds plus a final window flush with the
    end, so the entire clip is analysed.

    Returns:
        - a single (1, N) float tensor on DEVICE if return_multiple=False
        - (list of such tensors, duration in seconds) if return_multiple=True

    Raises:
        ValueError: wrapping any failure, with the message prefixed by
            "Error preprocessing audio: " and the original exception chained.
    """
    try:
        wav, sr = _load_audio_any(audio_input, is_base64=is_base64, base64_format=base64_format)

        if len(wav) == 0:
            raise ValueError("Empty audio file")
        if not np.isfinite(wav).all():
            raise ValueError("Invalid audio values")

        audio_duration = len(wav) / sr
        if audio_duration > MAX_AUDIO_DURATION:
            raise ValueError(f"Audio too long ({audio_duration:.1f}s). Maximum duration is {MAX_AUDIO_DURATION}s.")

        # Raises ValueError for unusable audio; the returned metrics are not needed here.
        _validate_audio_quality(wav, sr)

        wav = _denoise_spectral_gate_np(wav.astype(np.float32), sr)
        wav = _apply_bandpass_torch(torch.tensor(wav).float(), sr).cpu().numpy()
        wav = _normalize_waveform(wav)

        target_length = int(TARGET_DURATION * sr)
        if audio_duration <= TARGET_DURATION or not return_multiple:
            windows = [_extract_crop(wav, target_length, crop_type="center")]
        else:
            hop_length = int(SLIDING_WINDOW_HOP * sr)
            windows = [
                wav[start:start + target_length]
                for start in _sliding_window_starts(len(wav), target_length, hop_length)
            ]

        # Each window becomes a (1, N) batch on the inference device.
        tensors = [torch.tensor(window).float().unsqueeze(0).to(DEVICE) for window in windows]

        if return_multiple:
            return tensors, audio_duration
        return tensors[0]

    except Exception as e:
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e
|
|
|
|
|
|
|
|
def detect_ai_voice(audio_input, is_base64=False, language="English", threshold=None, base64_format: str | None = None):
    """
    Detect if voice is AI-generated or human.

    The whole clip is analysed: long audio is cut into TARGET_DURATION windows
    spaced SLIDING_WINDOW_HOP apart. Each window is scored as
    ``AASIST_WEIGHT * sigmoid(aasist) + OCSOFT_WEIGHT * sigmoid(ocsoft)``, and the
    window scores are averaged. The clip is AI_GENERATED when that mean reaches
    ``threshold``.

    Args:
        audio_input: file path, or base64 string when ``is_base64`` is True.
        is_base64: whether ``audio_input`` is base64-encoded audio.
        language: accepted for API compatibility; not used by the model.
        threshold: decision threshold in [0, 1]; defaults to OPTIMAL_THRESHOLD.
        base64_format: optional container hint for base64 input (e.g. "mp3").

    Returns:
        ``{"status": "success", "classification": "AI_GENERATED" | "HUMAN",
        "confidenceScore": float}``. The confidence grows with the distance of the
        score from the threshold and is reported in [0.8, 0.98].

    Raises:
        ValueError: an out-of-range threshold, or any preprocessing/inference
            failure (message prefixed by "Error processing audio: ").
    """
    try:
        if threshold is None:
            threshold = OPTIMAL_THRESHOLD
        threshold = float(threshold)
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"Threshold must be between 0.0 and 1.0 (got {threshold})")

        wav_windows, _duration = preprocess_audio(
            audio_input, is_base64=is_base64, base64_format=base64_format, return_multiple=True
        )

        all_scores = []
        with torch.no_grad():
            for wav in wav_windows:
                feats = wavlm(wav).last_hidden_state
                score_aasist = float(torch.sigmoid(aasist(feats)).item())
                score_oc = float(torch.sigmoid(ocsoft(feats)).item())
                all_scores.append(float(AASIST_WEIGHT * score_aasist + OCSOFT_WEIGHT * score_oc))

        final_score = float(np.mean(all_scores))

        # Decide against the configured threshold, so that OPTIMAL_THRESHOLD and
        # per-request thresholds actually take effect.
        if final_score >= threshold:
            classification = "AI_GENERATED"
            span = 1.0 - threshold
            margin = (final_score - threshold) / span if span > 0 else 1.0
        else:
            classification = "HUMAN"
            margin = (threshold - final_score) / threshold if threshold > 0 else 1.0

        # Map the margin [0, 1] to a raw confidence in [0.5, 1]. At threshold 0.5
        # this equals score (AI) or 1 - score (HUMAN). Then map it to the calibrated
        # reporting range [0.8, 0.98].
        raw_confidence = 0.5 + 0.5 * min(max(margin, 0.0), 1.0)
        confidence = min(0.8 + (raw_confidence - 0.5) * 0.36, 0.98)

        return {
            "status": "success",
            "classification": str(classification),
            "confidenceScore": float(confidence)
        }

    except Exception as e:
        raise ValueError(f"Error processing audio: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(
    title="AI Audio Detector API",
    description="API for detecting AI-generated vs human speech",
    version="1.0.0"
)

# Any origin may call the API. Authentication travels in the x-api-key header,
# not in cookies, so credentialed CORS is unnecessary. The FastAPI CORS docs warn
# against combining wildcard origins with allow_credentials=True, so it stays off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
async def verify_api_key(x_api_key: str = Header(...)):
    """
    Validate API key from request headers against environment variable.

    Used as a FastAPI dependency. The client's ``x-api-key`` header is compared
    with ``API_KEY`` in constant time, so response timing does not reveal how much
    of a guessed key was correct. If ``API_KEY`` is unset on the server, every
    request is rejected (fail closed).

    Returns:
        The supplied key, unchanged.

    Raises:
        HTTPException: 401 when the header is blank, the server has no key
            configured, or the key does not match.
    """
    import hmac

    if not x_api_key or len(x_api_key.strip()) == 0:
        raise HTTPException(
            status_code=401,
            detail={"status": "error", "message": "API key is required in x-api-key header"}
        )

    # compare_digest needs operands of the same type. Comparing UTF-8 bytes
    # avoids a TypeError on non-ASCII header values. An unset key never matches.
    if not API_KEY or not hmac.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=401,
            detail={"status": "error", "message": "Invalid API key"}
        )

    return x_api_key
|
|
|
|
|
|
|
|
|
|
|
class Base64AudioRequest(BaseModel):
    # JSON body for the base64 detection endpoints.
    language: str = Field(..., description="Language of the audio: Tamil, English, Hindi, Malayalam, Telugu")
    audioFormat: str = Field(..., description="Audio format (mp3)")
    audioBase64: str = Field(..., description="Base64 encoded audio file")
    threshold: Optional[float] = Field(None, description="Custom detection threshold (0.0-1.0)")

    @validator('language')
    def validate_language(cls, v):
        # Case-insensitive lookup that stores the canonical spelling.
        canonical = {lang.lower(): lang for lang in SUPPORTED_LANGUAGES}.get(v.lower())
        if canonical is None:
            raise ValueError(f"Language must be one of: {', '.join(SUPPORTED_LANGUAGES)}")
        return canonical

    @validator('audioFormat')
    def validate_format(cls, v):
        # Only MP3 is accepted; the value is stored lower-cased.
        normalized = v.lower()
        if normalized == "mp3":
            return normalized
        raise ValueError("Only MP3 format is supported")

    class Config:
        json_schema_extra = {
            "example": {
                "language": "Tamil",
                "audioFormat": "mp3",
                "audioBase64": "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU2LjM2LjEwMAAAAAAA..."
            }
        }
|
|
|
|
|
|
|
|
class DetectionResponse(BaseModel):
    # Result of a detection request. The endpoints in this file also build this
    # model for failures, with status "error" and placeholder values
    # (classification "HUMAN", confidenceScore 0.0), so clients must check
    # `status` first.
    status: str = Field(..., description="Status of the request: 'success' or 'error'")
    classification: str = Field(..., description="Classification: 'AI_GENERATED' or 'HUMAN'")
    confidenceScore: float = Field(..., description="Confidence score (0.0-1.0). Higher values indicate greater confidence in the classification")
|
|
|
|
|
|
|
|
class ErrorResponse(BaseModel):
    # Generic error payload shape ({"status": "error", "message": ...}).
    # NOTE(review): no endpoint visible here uses it as a response_model; the
    # HTTPException details in this file use the same two keys.
    status: str = Field("error", description="Status of the request")
    message: str = Field(..., description="Error message")
|
|
|
|
|
|
|
|
class Base64EncodeResponse(BaseModel):
    # Response of POST /api/encode-to-base64.
    status: str = Field(..., description="Status of the request")
    filename: str = Field(..., description="Original filename")
    # Size of the uploaded bytes, before encoding.
    fileSize: int = Field(..., description="File size in bytes")
    base64Length: int = Field(..., description="Length of base64 string")
    audioBase64: str = Field(..., description="Base64 encoded audio string")
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Root endpoint: service overview, limits, authentication and endpoint list."""
    endpoint_catalog = {
        "POST /api/detect-from-file": "Upload audio file directly - easiest method! (requires API key)",
        "POST /api/voice-detection": "Detect AI voice from base64 MP3 audio (requires API key)",
        "POST /api/encode-to-base64": "Encode audio file to base64 string (requires API key)",
        "GET /health": "Health check endpoint",
        "GET /docs": "Interactive API documentation",
    }
    overview = {
        "message": "AI Audio Detector API - Voice Classification System",
        "version": "1.0.0",
        "description": "Detects AI-generated vs Human voice across multiple languages",
        "supported_languages": SUPPORTED_LANGUAGES,
        "max_audio_duration": f"{MAX_AUDIO_DURATION}s",
        "processing_method": "Sliding window analysis for complete audio coverage",
        "authentication": "Required: x-api-key header",
        "endpoints": endpoint_catalog,
        "classification_types": ["AI_GENERATED", "HUMAN"],
        "confidence_range": "Confidence scores range from 0.0 to 1.0",
    }
    return overview
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: device, checkpoint presence, active threshold and languages."""
    # NOTE(review): "model_loaded" reports whether the checkpoint file exists now,
    # not whether it was actually loaded at startup.
    return dict(
        status="healthy",
        device=str(DEVICE),
        model_loaded=os.path.exists(MODEL_PATH),
        threshold=OPTIMAL_THRESHOLD,
        supported_languages=SUPPORTED_LANGUAGES,
        api_version="1.0.0",
    )
|
|
|
|
|
|
|
|
@app.post("/api/voice-detection", response_model=DetectionResponse)
async def voice_detection(
    request: Base64AudioRequest,
    api_key: str = Depends(verify_api_key)
):
    """
    Detect AI-generated voice from base64 encoded audio

    Required Headers:
    - **x-api-key**: Your API key for authentication

    Request Body:
    - **language**: Language of the audio (Tamil, English, Hindi, Malayalam, Telugu)
    - **audioFormat**: Audio format (mp3)
    - **audioBase64**: Base64 encoded audio file

    On failure the response is still HTTP 200, with `status: "error"`, placeholder
    `classification`/`confidenceScore`, and a `message` explaining the problem.
    """
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    try:
        # Inference is blocking CPU/GPU work. Running it in the worker threadpool
        # keeps the event loop free to serve other requests (e.g. /health).
        result = await run_in_threadpool(
            detect_ai_voice,
            audio_input=request.audioBase64,
            is_base64=True,
            language=request.language,
            threshold=request.threshold,
            base64_format=request.audioFormat,
        )
        return DetectionResponse(**result)
    except ValueError as e:
        # ValueErrors from the pipeline carry user-facing explanations.
        message = str(e)
    except Exception as e:
        print(f"[voice-detection] unexpected error: {e!r}")
        message = "Internal error while processing audio"

    # Same payload as before, plus a "message" field describing the failure.
    return JSONResponse(
        status_code=200,
        content={
            "status": "error",
            "classification": "HUMAN",
            "confidenceScore": 0.0,
            "message": message,
        },
    )
|
|
|
|
|
|
|
|
@app.post("/detect/base64", response_model=DetectionResponse)
async def detect_from_base64(
    request: Base64AudioRequest,
    api_key: str = Depends(verify_api_key)
):
    """
    Legacy endpoint - use /api/voice-detection instead
    """
    # The API key has already been checked by this route's own dependency, so the
    # validated key is passed straight through to the current handler.
    response = await voice_detection(request=request, api_key=api_key)
    return response
|
|
|
|
|
|
|
|
@app.post("/api/detect-from-file", response_model=DetectionResponse)
async def detect_from_file(
    file: UploadFile = File(..., description="Audio file (MP3, WAV, FLAC, etc.)"),
    language: str = "English",
    threshold: Optional[float] = None,
    api_key: str = Depends(verify_api_key)
):
    """
    Direct audio file upload endpoint - no base64 encoding needed!

    Upload an audio file directly and get AI detection results.
    The API handles all preprocessing automatically.

    Required Headers:
    - **x-api-key**: Your API key for authentication

    Form Data:
    - **file**: Audio file to analyze (MP3, WAV, FLAC, etc.)
    - **language**: Language of the audio (optional, default: English)
    - **threshold**: Custom detection threshold 0.0-1.0 (optional)

    Returns the same DetectionResponse as /api/voice-detection. On failure the
    response is HTTP 200 with `status: "error"` and a `message` field.
    """
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    def error_response(message: str):
        # Same error payload as before, plus a "message" describing the failure.
        return JSONResponse(
            status_code=200,
            content={
                "status": "error",
                "classification": "HUMAN",
                "confidenceScore": 0.0,
                "message": message,
            },
        )

    # Case-insensitive language match, normalised to the canonical spelling.
    validated_language = next(
        (lang for lang in SUPPORTED_LANGUAGES if lang.lower() == language.lower()),
        None,
    )
    if validated_language is None:
        return error_response(f"Language must be one of: {', '.join(SUPPORTED_LANGUAGES)}")

    # The chained comparison also rejects NaN.
    if threshold is not None and not (0.0 <= threshold <= 1.0):
        return error_response("Threshold must be between 0.0 and 1.0")

    # Keep the client's extension, since decoders may use it to pick a codec, but
    # only if it is a short ASCII alphanumeric suffix; otherwise assume MP3.
    file_ext = os.path.splitext(file.filename or "audio.mp3")[1]
    ext_body = file_ext[1:]
    if not (ext_body and len(ext_body) <= 10 and ext_body.isascii() and ext_body.isalnum()):
        file_ext = ".mp3"

    # mkstemp creates the file exclusively with a unique name in the temp dir.
    fd, temp_path = tempfile.mkstemp(prefix="upload_", suffix=file_ext)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())

        # Blocking inference runs off the event loop.
        result = await run_in_threadpool(
            detect_ai_voice,
            audio_input=temp_path,
            is_base64=False,
            language=validated_language,
            threshold=threshold,
            base64_format=None,
        )
        return DetectionResponse(**result)

    except ValueError as e:
        return error_response(str(e))
    except Exception as e:
        print(f"[detect-from-file] unexpected error: {e!r}")
        return error_response("Internal error while processing audio")
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            # Best effort: a leftover temp file must not change the response.
            pass
|
|
|
|
|
|
|
|
@app.post("/api/encode-to-base64", response_model=Base64EncodeResponse)
async def encode_audio_to_base64(
    file: UploadFile = File(..., description="Audio file to encode to base64"),
    api_key: str = Depends(verify_api_key)
):
    """
    Upload an audio file and get back its base64 encoded string.
    Handy for preparing request bodies when testing the voice detection API.

    Required Headers:
    - **x-api-key**: Your API key for authentication

    Request:
    - **file**: Audio file to encode (any format)
    """
    try:
        raw = await file.read()
        encoded = b64.b64encode(raw).decode('utf-8')
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"status": "error", "message": f"Error encoding file: {str(e)}"},
        )

    try:
        return Base64EncodeResponse(
            status="success",
            filename=file.filename or "unknown",
            fileSize=len(raw),
            base64Length=len(encoded),
            audioBase64=encoded,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"status": "error", "message": f"Error encoding file: {str(e)}"},
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Defaults keep the previous 0.0.0.0:8000 binding. The HOST and PORT
    # environment variables let hosting platforms choose without code edits.
    uvicorn.run(app, host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", "8000")))