RJ40under40 committed on
Commit
0c8ad6a
·
verified ·
1 Parent(s): 463ac7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -5,16 +5,13 @@ import numpy as np
5
  import torch
6
  import librosa
7
  import uvicorn
8
-
9
  from fastapi import FastAPI, HTTPException, Security, Depends
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.security.api_key import APIKeyHeader
12
  from pydantic import BaseModel
13
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
14
 
15
- # ======================================================
16
- # CONFIG
17
- # ======================================================
18
  API_KEY_NAME = "access_token"
19
  API_KEY_VALUE = "HCL_SECURE_KEY_2026"
20
  MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
@@ -24,15 +21,13 @@ LABEL_MAP = {0: "HUMAN", 1: "AI_GENERATED"}
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger("hcl-api")
26
 
 
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
-
29
- # Load Model
30
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
31
  model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
32
  model.eval()
33
 
34
  app = FastAPI(title="HCL AI Voice Detection API")
35
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
36
 
37
  app.add_middleware(
38
  CORSMiddleware,
@@ -44,6 +39,8 @@ app.add_middleware(
44
  class AudioRequest(BaseModel):
45
  audio_base64: str
46
 
 
 
47
  async def verify_api_key(api_key: str = Security(api_key_header)):
48
  if api_key != API_KEY_VALUE:
49
  raise HTTPException(status_code=403, detail="Invalid API Key")
@@ -51,18 +48,17 @@ async def verify_api_key(api_key: str = Security(api_key_header)):
51
 
52
  def preprocess_audio(b64_string: str):
53
  try:
54
- # Clean Base64 header and fix padding
55
  if "," in b64_string:
56
  b64_string = b64_string.split(",")[1]
57
 
 
58
  missing_padding = len(b64_string) % 4
59
  if missing_padding:
60
  b64_string += "=" * (4 - missing_padding)
61
 
62
  audio_bytes = base64.b64decode(b64_string)
63
 
64
- # Wrap bytes in BytesIO and load with librosa
65
- # librosa handles MP3 decoding better than soundfile in many Linux envs
66
  with io.BytesIO(audio_bytes) as bio:
67
  audio, sr = librosa.load(bio, sr=TARGET_SR)
68
 
@@ -74,16 +70,15 @@ def preprocess_audio(b64_string: str):
74
  logger.error(f"Preprocessing error: {e}")
75
  raise ValueError(f"Decoding failed: {str(e)}")
76
 
 
 
 
 
77
  @app.post("/predict")
78
  async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
79
  try:
80
- waveform = preprocess_audio(request.audio_base64)
81
-
82
- inputs = feature_extractor(
83
- waveform,
84
- sampling_rate=TARGET_SR,
85
- return_tensors="pt"
86
- ).to(DEVICE)
87
 
88
  with torch.inference_mode():
89
  logits = model(**inputs).logits
@@ -98,7 +93,8 @@ async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
98
  except ValueError as ve:
99
  raise HTTPException(status_code=400, detail=str(ve))
100
  except Exception as e:
 
101
  raise HTTPException(status_code=500, detail="Internal Server Error")
102
 
103
  if __name__ == "__main__":
104
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
5
  import torch
6
  import librosa
7
  import uvicorn
 
8
  from fastapi import FastAPI, HTTPException, Security, Depends
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.security.api_key import APIKeyHeader
11
  from pydantic import BaseModel
12
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
13
 
14
+ # Config
 
 
15
  API_KEY_NAME = "access_token"
16
  API_KEY_VALUE = "HCL_SECURE_KEY_2026"
17
  MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger("hcl-api")
23
 
24
+ # Initialize Model
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
26
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
27
  model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
28
  model.eval()
29
 
30
  app = FastAPI(title="HCL AI Voice Detection API")
 
31
 
32
  app.add_middleware(
33
  CORSMiddleware,
 
39
  class AudioRequest(BaseModel):
40
  audio_base64: str
41
 
42
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
43
+
44
  async def verify_api_key(api_key: str = Security(api_key_header)):
45
  if api_key != API_KEY_VALUE:
46
  raise HTTPException(status_code=403, detail="Invalid API Key")
 
48
 
49
  def preprocess_audio(b64_string: str):
50
  try:
 
51
  if "," in b64_string:
52
  b64_string = b64_string.split(",")[1]
53
 
54
+ # Correct padding
55
  missing_padding = len(b64_string) % 4
56
  if missing_padding:
57
  b64_string += "=" * (4 - missing_padding)
58
 
59
  audio_bytes = base64.b64decode(b64_string)
60
 
61
+ # Load via librosa for better MP3 compatibility
 
62
  with io.BytesIO(audio_bytes) as bio:
63
  audio, sr = librosa.load(bio, sr=TARGET_SR)
64
 
 
70
  logger.error(f"Preprocessing error: {e}")
71
  raise ValueError(f"Decoding failed: {str(e)}")
72
 
73
+ @app.get("/")
74
+ def home():
75
+ return {"message": "API is running. Visit /docs for Swagger UI"}
76
+
77
  @app.post("/predict")
78
  async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
79
  try:
80
  + waveform = preprocess_audio(request.audio_base64)
81
+ inputs = feature_extractor(waveform, sampling_rate=TARGET_SR, return_tensors="pt").to(DEVICE)
 
 
 
 
 
82
 
83
  with torch.inference_mode():
84
  logits = model(**inputs).logits
 
93
  except ValueError as ve:
94
  raise HTTPException(status_code=400, detail=str(ve))
95
  except Exception as e:
96
+ logger.error(f"Prediction error: {e}")
97
  raise HTTPException(status_code=500, detail="Internal Server Error")
98
 
99
  if __name__ == "__main__":
100
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)