RJ40under40 committed on
Commit
97dd4a0
·
verified ·
1 Parent(s): 7e73c0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -74
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # ======================================================
2
- # HCL AI VOICE DETECTION API – CRASH-PROOF VERSION
3
  # ======================================================
4
 
5
  import base64
@@ -10,45 +10,50 @@ import torch
10
  import soundfile as sf
11
  import librosa
12
 
13
- from fastapi import FastAPI, HTTPException, Depends, Security
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.security.api_key import APIKeyHeader
16
  from pydantic import BaseModel
17
-
18
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
19
 
20
  # ======================================================
21
- # CONFIG
22
  # ======================================================
 
 
 
 
 
 
23
  API_KEY_NAME = "access_token"
24
- API_KEY_VALUE = "HCL_SECURE_KEY_2026"
25
 
26
- MODEL_ID = "superb/wav2vec2-base-superb-ks"
 
27
  TARGET_SR = 16000
28
 
29
  # ======================================================
30
- # LOGGING
31
  # ======================================================
32
  logging.basicConfig(level=logging.INFO)
33
- logger = logging.getLogger("voice-detection")
34
 
35
- # ======================================================
36
- # DEVICE & MODEL
37
- # ======================================================
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
- logger.info(f"Using device: {DEVICE}")
40
 
41
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
42
- model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
43
- model.eval()
 
 
 
 
44
 
45
  # ======================================================
46
- # FASTAPI APP
47
  # ======================================================
48
  app = FastAPI(title="HCL AI Voice Detection API")
49
 
50
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
51
-
52
  app.add_middleware(
53
  CORSMiddleware,
54
  allow_origins=["*"],
@@ -56,85 +61,90 @@ app.add_middleware(
56
  allow_headers=["*"],
57
  )
58
 
59
- # ======================================================
60
- # SCHEMA
61
- # ======================================================
62
  class AudioRequest(BaseModel):
63
  audio_base64: str
64
 
 
 
65
  # ======================================================
66
- # SECURITY
67
  # ======================================================
68
  async def verify_api_key(api_key: str = Security(api_key_header)):
69
  if api_key != API_KEY_VALUE:
70
  raise HTTPException(status_code=403, detail="Invalid API Key")
71
  return api_key
72
 
73
- # ======================================================
74
- # AUDIO DECODING (SAFE)
75
- # ======================================================
76
- def decode_audio(b64_audio: str):
77
- audio_bytes = base64.b64decode(b64_audio.split(",")[-1])
78
- audio, sr = sf.read(io.BytesIO(audio_bytes))
79
-
80
- if audio.ndim > 1:
81
- audio = np.mean(audio, axis=1)
82
-
83
- if sr != TARGET_SR:
84
- audio = librosa.resample(audio.astype(float), sr, TARGET_SR)
85
-
86
- audio = np.nan_to_num(audio)
87
-
88
- if len(audio) < TARGET_SR:
89
- audio = np.pad(audio, (0, TARGET_SR - len(audio)))
90
-
91
- return audio.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # ======================================================
94
- # INFERENCE (CRASH-PROOF)
95
  # ======================================================
96
- def analyze_voice(audio):
 
 
 
 
 
 
 
 
97
  try:
 
 
 
 
98
  inputs = feature_extractor(
99
- audio,
100
- sampling_rate=TARGET_SR,
101
- return_tensors="pt",
102
  padding=True
103
- )
104
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
105
 
106
  with torch.inference_mode():
107
  logits = model(**inputs).logits
108
  probs = torch.softmax(logits, dim=-1)
109
 
110
- score, pred = torch.max(probs, dim=-1)
 
 
111
 
 
112
  return {
113
- "classification": "UNKNOWN",
114
- "confidence_score": round(score.item(), 4),
115
- "raw_label_index": int(pred.item())
116
  }
117
 
 
 
118
  except Exception as e:
119
- logger.exception("Model inference failed")
120
- return {
121
- "classification": "MODEL_ERROR",
122
- "confidence_score": 0.0,
123
- "error": str(e)
124
- }
125
-
126
- # ======================================================
127
- # ENDPOINTS
128
- # ======================================================
129
- @app.get("/health")
130
- def health():
131
- return {"status": "ok", "device": DEVICE}
132
-
133
- @app.post("/predict")
134
- async def predict(
135
- request: AudioRequest,
136
- _: str = Depends(verify_api_key)
137
- ):
138
- audio = decode_audio(request.audio_base64)
139
- result = analyze_voice(audio)
140
- return result
 
1
  # ======================================================
2
+ # HCL AI VOICE DETECTION API – HACKATHON SUBMISSION
3
  # ======================================================
4
 
5
  import base64
 
10
  import soundfile as sf
11
  import librosa
12
 
13
+ from fastapi import FastAPI, HTTPException, Security, Depends
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.security.api_key import APIKeyHeader
16
  from pydantic import BaseModel
 
17
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
18
 
19
  # ======================================================
20
+ # CONFIG & REQUIREMENTS MAPPING
21
  # ======================================================
22
+ # The hackathon requires specific classification results
23
+ LABEL_MAP = {
24
+ 0: "HUMAN",
25
+ 1: "AI_GENERATED"
26
+ }
27
+
28
  API_KEY_NAME = "access_token"
29
+ API_KEY_VALUE = "HCL_SECURE_KEY_2026" # Ensure this matches your submission docs
30
 
31
+ # Using a model fine-tuned for Deepfake/Synthetic Voice Detection
32
+ MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
33
  TARGET_SR = 16000
34
 
35
  # ======================================================
36
+ # INITIALIZATION
37
  # ======================================================
38
  logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger("hcl-voice-safety")
40
 
 
 
 
41
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
42
+ logger.info(f"Loading model to {DEVICE}...")
43
 
44
+ try:
45
+ feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
46
+ model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
47
+ model.eval()
48
+ logger.info("Model loaded successfully.")
49
+ except Exception as e:
50
+ logger.error(f"Failed to load model: {e}")
51
 
52
  # ======================================================
53
+ # FASTAPI SETUP
54
  # ======================================================
55
  app = FastAPI(title="HCL AI Voice Detection API")
56
 
 
 
57
  app.add_middleware(
58
  CORSMiddleware,
59
  allow_origins=["*"],
 
61
  allow_headers=["*"],
62
  )
63
 
 
 
 
64
  class AudioRequest(BaseModel):
65
  audio_base64: str
66
 
67
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
68
+
69
  # ======================================================
70
+ # UTILITIES
71
  # ======================================================
72
async def verify_api_key(api_key: str = Security(api_key_header)):
    """Gate an endpoint behind the shared API key.

    Returns the key when it matches API_KEY_VALUE; otherwise rejects the
    request with HTTP 403.
    """
    if api_key == API_KEY_VALUE:
        return api_key
    raise HTTPException(status_code=403, detail="Invalid API Key")
76
 
77
def preprocess_audio(b64_string: str):
    """Decode a Base64 audio payload into a 16 kHz mono float32 waveform.

    Accepts an optional data-URI prefix (e.g. ``data:audio/mp3;base64,...``).
    Stereo input is down-mixed to mono, the signal is resampled to TARGET_SR,
    NaN/Inf samples are zeroed, and clips shorter than one second are
    zero-padded up to TARGET_SR samples.

    Args:
        b64_string: Base64-encoded audio bytes, optionally with a data-URI header.

    Returns:
        np.ndarray of dtype float32, mono, sampled at TARGET_SR.

    Raises:
        ValueError: if the payload cannot be decoded as valid Base64 audio.
    """
    try:
        # Strip a data-URI header if present. maxsplit=1 keeps the whole
        # payload intact even if the Base64 body itself were to contain a comma
        # (the original split(",")[1] would truncate it at the next comma).
        if "," in b64_string:
            b64_string = b64_string.split(",", 1)[1]

        audio_bytes = base64.b64decode(b64_string)

        # NOTE(review): sf.read handles WAV/FLAC natively; MP3 support depends
        # on the libsndfile build in the deployment environment — confirm.
        with io.BytesIO(audio_bytes) as bio:
            audio, sr = sf.read(bio)

        # Down-mix multi-channel audio to mono.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        # Resample to the model's expected rate.
        if sr != TARGET_SR:
            audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=TARGET_SR)

        # Zero out NaN/Inf samples and pad sub-second clips for stability.
        audio = np.nan_to_num(audio)
        if len(audio) < TARGET_SR:
            audio = np.pad(audio, (0, TARGET_SR - len(audio)))

        return audio.astype(np.float32)
    except Exception as e:
        # Lazy %-style args avoid formatting cost when the level is filtered;
        # "from e" chains the root cause into the traceback for debugging.
        logger.error("Audio processing error: %s", e)
        raise ValueError("Could not decode audio. Ensure it is a valid Base64 MP3/WAV.") from e
107
 
108
  # ======================================================
109
+ # ENDPOINTS
110
  # ======================================================
111
@app.get("/health")
def health():
    """Liveness probe: report service status and the inference device in use."""
    payload = {"status": "active", "device": DEVICE}
    return payload
114
+
115
@app.post("/predict")
async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
    """
    Analyzes voice sample and classifies as AI_GENERATED or HUMAN.

    Request body: AudioRequest with `audio_base64` (Base64 WAV/MPP3/WAV payload,
    optionally with a data-URI prefix).

    Returns:
        dict with "classification" (a LABEL_MAP value or "UNKNOWN") and
        "confidence_score" (softmax probability rounded to 4 decimals).

    Raises:
        HTTPException 400: audio could not be decoded (ValueError from
            preprocess_audio).
        HTTPException 403: raised upstream by verify_api_key on a bad key.
        HTTPException 500: any other failure during inference.
    """
    try:
        # 1. Preprocess: decode Base64 -> 16 kHz mono float32 waveform.
        waveform = preprocess_audio(request.audio_base64)

        # 2. Inference: tokenize and move the whole batch to DEVICE.
        inputs = feature_extractor(
            waveform,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
            padding=True
        ).to(DEVICE)

        # inference_mode disables autograd bookkeeping for speed/memory.
        with torch.inference_mode():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)

        # 3. Get results: highest-probability class and its confidence.
        # NOTE(review): LABEL_MAP assumes index 0 = HUMAN, 1 = AI_GENERATED —
        # confirm against model.config.id2label for MODEL_ID.
        confidence, pred_idx = torch.max(probs, dim=-1)
        label = LABEL_MAP.get(int(pred_idx.item()), "UNKNOWN")

        # 4. Return structured JSON
        return {
            "classification": label,
            "confidence_score": round(float(confidence.item()), 4)
        }

    except ValueError as ve:
        # Client-side problem (bad/undecodable audio) -> 400.
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        # Anything else is a server fault; log the traceback, hide details.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Internal server error during analysis")