RJ40under40 committed on
Commit
4b23c1b
·
verified ·
1 Parent(s): 97dd4a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -73
app.py CHANGED
@@ -1,14 +1,10 @@
1
- # ======================================================
2
- # HCL AI VOICE DETECTION API – HACKATHON SUBMISSION
3
- # ======================================================
4
-
5
  import base64
6
  import io
7
  import logging
8
  import numpy as np
9
  import torch
10
- import soundfile as sf
11
  import librosa
 
12
 
13
  from fastapi import FastAPI, HTTPException, Security, Depends
14
  from fastapi.middleware.cors import CORSMiddleware
@@ -17,42 +13,26 @@ from pydantic import BaseModel
17
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
18
 
19
  # ======================================================
20
- # CONFIG & REQUIREMENTS MAPPING
21
  # ======================================================
22
- # The hackathon requires specific classification results
23
- LABEL_MAP = {
24
- 0: "HUMAN",
25
- 1: "AI_GENERATED"
26
- }
27
-
28
  API_KEY_NAME = "access_token"
29
- API_KEY_VALUE = "HCL_SECURE_KEY_2026" # Ensure this matches your submission docs
30
-
31
- # Using a model fine-tuned for Deepfake/Synthetic Voice Detection
32
  MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
33
  TARGET_SR = 16000
 
34
 
35
- # ======================================================
36
- # INITIALIZATION
37
- # ======================================================
38
  logging.basicConfig(level=logging.INFO)
39
- logger = logging.getLogger("hcl-voice-safety")
40
 
41
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
42
- logger.info(f"Loading model to {DEVICE}...")
43
 
44
- try:
45
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
46
- model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
47
- model.eval()
48
- logger.info("Model loaded successfully.")
49
- except Exception as e:
50
- logger.error(f"Failed to load model: {e}")
51
 
52
- # ======================================================
53
- # FASTAPI SETUP
54
- # ======================================================
55
  app = FastAPI(title="HCL AI Voice Detection API")
 
56
 
57
  app.add_middleware(
58
  CORSMiddleware,
@@ -64,87 +44,61 @@ app.add_middleware(
64
  class AudioRequest(BaseModel):
65
  audio_base64: str
66
 
67
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
68
-
69
- # ======================================================
70
- # UTILITIES
71
- # ======================================================
72
  async def verify_api_key(api_key: str = Security(api_key_header)):
73
  if api_key != API_KEY_VALUE:
74
  raise HTTPException(status_code=403, detail="Invalid API Key")
75
  return api_key
76
 
77
  def preprocess_audio(b64_string: str):
78
- """Decodes base64 MP3/WAV and converts to 16kHz Mono."""
79
  try:
80
- # Strip header if present (e.g., data:audio/mp3;base64,...)
81
  if "," in b64_string:
82
  b64_string = b64_string.split(",")[1]
83
 
 
 
 
 
84
  audio_bytes = base64.b64decode(b64_string)
85
 
86
- # Use soundfile for reading. Note: For MP3, ensure 'audioread' or 'ffmpeg' is in the environment
 
87
  with io.BytesIO(audio_bytes) as bio:
88
- audio, sr = sf.read(bio)
89
 
90
- # Convert to Mono if Stereo
91
- if len(audio.shape) > 1:
92
- audio = np.mean(audio, axis=1)
93
-
94
- # Resample to 16kHz
95
- if sr != TARGET_SR:
96
- audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=TARGET_SR)
97
-
98
- # Normalization & Padding for stability
99
- audio = np.nan_to_num(audio)
100
  if len(audio) < TARGET_SR:
101
  audio = np.pad(audio, (0, TARGET_SR - len(audio)))
102
 
103
  return audio.astype(np.float32)
104
  except Exception as e:
105
- logger.error(f"Audio processing error: {e}")
106
- raise ValueError("Could not decode audio. Ensure it is a valid Base64 MP3/WAV.")
107
-
108
- # ======================================================
109
- # ENDPOINTS
110
- # ======================================================
111
- @app.get("/health")
112
- def health():
113
- return {"status": "active", "device": DEVICE}
114
 
115
  @app.post("/predict")
116
  async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
117
- """
118
- Analyzes voice sample and classifies as AI_GENERATED or HUMAN.
119
- """
120
  try:
121
- # 1. Preprocess
122
  waveform = preprocess_audio(request.audio_base64)
123
 
124
- # 2. Inference
125
  inputs = feature_extractor(
126
  waveform,
127
  sampling_rate=TARGET_SR,
128
- return_tensors="pt",
129
- padding=True
130
  ).to(DEVICE)
131
 
132
  with torch.inference_mode():
133
  logits = model(**inputs).logits
134
  probs = torch.softmax(logits, dim=-1)
135
 
136
- # 3. Get results
137
  confidence, pred_idx = torch.max(probs, dim=-1)
138
- label = LABEL_MAP.get(int(pred_idx.item()), "UNKNOWN")
139
-
140
- # 4. Return structured JSON
141
  return {
142
- "classification": label,
143
  "confidence_score": round(float(confidence.item()), 4)
144
  }
145
-
146
  except ValueError as ve:
147
  raise HTTPException(status_code=400, detail=str(ve))
148
  except Exception as e:
149
- logger.exception("Prediction failed")
150
- raise HTTPException(status_code=500, detail="Internal server error during analysis")
 
 
 
 
 
 
 
1
  import base64
2
  import io
3
  import logging
4
  import numpy as np
5
  import torch
 
6
  import librosa
7
+ import uvicorn
8
 
9
  from fastapi import FastAPI, HTTPException, Security, Depends
10
  from fastapi.middleware.cors import CORSMiddleware
 
13
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
14
 
15
# ======================================================
# CONFIG
# ======================================================

# Name of the HTTP header clients must send the key in.
API_KEY_NAME = "access_token"
# NOTE(review): hard-coded secret committed to source — consider loading
# this from an environment variable before any public deployment.
API_KEY_VALUE = "HCL_SECURE_KEY_2026"

# Hugging Face model id for synthetic-speech detection.
MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
# Sample rate (Hz) the feature extractor is fed; preprocess_audio resamples to this.
TARGET_SR = 16000
# Maps the model's output class index to the API's label strings.
# assumes index 0 = genuine speech and 1 = synthetic — TODO confirm against the model card
LABEL_MAP = {0: "HUMAN", 1: "AI_GENERATED"}

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hcl-api")

# Prefer GPU when available; all tensors are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load Model
# NOTE(review): loading at import time means the process fails fast at
# startup if the model cannot be fetched — there is no fallback path.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()
 
 
 
33
 
 
 
 
34
app = FastAPI(title="HCL AI Voice Detection API")
# auto_error=False: a missing header is handed to verify_api_key (which
# rejects it with 403) rather than erroring inside the security scheme —
# presumably intended so all failures share one response; verify.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
36
 
37
  app.add_middleware(
38
  CORSMiddleware,
 
44
class AudioRequest(BaseModel):
    """Request body for /predict."""
    # Base64-encoded audio; a "data:...;base64," data-URI prefix is tolerated
    # (stripped in preprocess_audio).
    audio_base64: str
46
 
 
 
 
 
 
47
async def verify_api_key(api_key: str = Security(api_key_header)):
    """FastAPI dependency that authenticates a request by API key.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    import secrets  # local import: only needed for the comparison below

    # Use a constant-time comparison so response timing cannot leak how
    # many leading characters of the key an attacker has guessed.
    if not (api_key and secrets.compare_digest(api_key, API_KEY_VALUE)):
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return api_key
51
 
52
def preprocess_audio(b64_string: str):
    """Decode a Base64 audio payload to a 16 kHz mono float32 waveform.

    Accepts raw Base64 or a data URI ("data:audio/...;base64,<payload>").
    Short clips are right-padded to at least TARGET_SR samples (1 second).

    Raises:
        ValueError: when the payload cannot be decoded as Base64 audio.
    """
    try:
        # Clean Base64 header: split only on the FIRST comma so the
        # payload itself is never truncated.
        if "," in b64_string:
            b64_string = b64_string.split(",", 1)[1]

        # Restore Base64 padding that clients commonly strip.
        missing_padding = len(b64_string) % 4
        if missing_padding:
            b64_string += "=" * (4 - missing_padding)

        audio_bytes = base64.b64decode(b64_string)

        # Wrap bytes in BytesIO and load with librosa (sr=TARGET_SR also
        # resamples and downmixes to mono); librosa handles MP3 decoding
        # better than soundfile in many Linux envs.
        with io.BytesIO(audio_bytes) as bio:
            audio, sr = librosa.load(bio, sr=TARGET_SR)

        # Right-pad very short clips so the model always sees >= 1 s.
        if len(audio) < TARGET_SR:
            audio = np.pad(audio, (0, TARGET_SR - len(audio)))

        return audio.astype(np.float32)
    except Exception as e:
        # Lazy %-formatting keeps the log call cheap when INFO is filtered.
        logger.error("Preprocessing error: %s", e)
        # Re-raise as ValueError (mapped to HTTP 400 by the endpoint),
        # chaining the original cause for debuggability.
        raise ValueError(f"Decoding failed: {str(e)}") from e
 
 
 
 
 
 
 
76
 
77
@app.post("/predict")
async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
    """Classify a Base64 audio sample as HUMAN or AI_GENERATED.

    Returns:
        dict: {"classification": <LABEL_MAP entry or "UNKNOWN">,
               "confidence_score": <top softmax probability, 4 dp>}

    Raises:
        HTTPException: 400 for undecodable audio, 500 for unexpected errors.
    """
    try:
        # 1) Decode + resample to the 16 kHz mono waveform the model expects.
        waveform = preprocess_audio(request.audio_base64)

        # 2) Feature extraction -> tensors on the serving device.
        inputs = feature_extractor(
            waveform,
            sampling_rate=TARGET_SR,
            return_tensors="pt"
        ).to(DEVICE)

        # 3) Forward pass without autograd bookkeeping.
        with torch.inference_mode():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)

        # 4) Top class index and its probability.
        confidence, pred_idx = torch.max(probs, dim=-1)

        return {
            "classification": LABEL_MAP.get(int(pred_idx.item()), "UNKNOWN"),
            "confidence_score": round(float(confidence.item()), 4)
        }
    except ValueError as ve:
        # Bad payload, raised by preprocess_audio.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Log the full traceback so 500s are diagnosable from server logs;
        # the client still only sees a generic message.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
102
+
103
if __name__ == "__main__":
    # Dev/container entry point: bind all interfaces on port 7860
    # (presumably chosen for Hugging Face Spaces — confirm with deployment).
    uvicorn.run(app, host="0.0.0.0", port=7860)