RJ40under40 committed on
Commit
9301dd7
·
verified ·
1 Parent(s): 0c8ad6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -15
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import base64
2
  import io
3
  import logging
@@ -5,15 +6,22 @@ import numpy as np
5
  import torch
6
  import librosa
7
  import uvicorn
 
8
  from fastapi import FastAPI, HTTPException, Security, Depends
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.security.api_key import APIKeyHeader
11
  from pydantic import BaseModel
12
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
13
 
14
- # Config
 
 
15
  API_KEY_NAME = "access_token"
16
  API_KEY_VALUE = "HCL_SECURE_KEY_2026"
 
 
 
 
17
  MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
18
  TARGET_SR = 16000
19
  LABEL_MAP = {0: "HUMAN", 1: "AI_GENERATED"}
@@ -21,13 +29,34 @@ LABEL_MAP = {0: "HUMAN", 1: "AI_GENERATED"}
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger("hcl-api")
23
 
24
- # Initialize Model
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
- feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
27
- model = AutoModelForAudioClassification.from_pretrained(MODEL_ID).to(DEVICE)
28
- model.eval()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  app = FastAPI(title="HCL AI Voice Detection API")
 
31
 
32
  app.add_middleware(
33
  CORSMiddleware,
@@ -39,8 +68,6 @@ app.add_middleware(
39
  class AudioRequest(BaseModel):
40
  audio_base64: str
41
 
42
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
43
-
44
  async def verify_api_key(api_key: str = Security(api_key_header)):
45
  if api_key != API_KEY_VALUE:
46
  raise HTTPException(status_code=403, detail="Invalid API Key")
@@ -51,14 +78,14 @@ def preprocess_audio(b64_string: str):
51
  if "," in b64_string:
52
  b64_string = b64_string.split(",")[1]
53
 
54
- # Correct padding
55
  missing_padding = len(b64_string) % 4
56
  if missing_padding:
57
  b64_string += "=" * (4 - missing_padding)
58
 
59
  audio_bytes = base64.b64decode(b64_string)
60
 
61
- # Load via librosa for better MP3 compatibility
62
  with io.BytesIO(audio_bytes) as bio:
63
  audio, sr = librosa.load(bio, sr=TARGET_SR)
64
 
@@ -72,13 +99,21 @@ def preprocess_audio(b64_string: str):
72
 
73
  @app.get("/")
74
  def home():
75
- return {"message": "API is running. Visit /docs for Swagger UI"}
76
 
77
  @app.post("/predict")
78
  async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
 
 
 
79
  try:
80
- waveform = preprocess_audio(request.audio_base_64)
81
- inputs = feature_extractor(waveform, sampling_rate=TARGET_SR, return_tensors="pt").to(DEVICE)
 
 
 
 
 
82
 
83
  with torch.inference_mode():
84
  logits = model(**inputs).logits
@@ -86,15 +121,19 @@ async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
86
 
87
  confidence, pred_idx = torch.max(probs, dim=-1)
88
 
 
 
 
89
  return {
90
- "classification": LABEL_MAP.get(int(pred_idx.item()), "UNKNOWN"),
91
  "confidence_score": round(float(confidence.item()), 4)
92
  }
93
  except ValueError as ve:
94
  raise HTTPException(status_code=400, detail=str(ve))
95
  except Exception as e:
96
  logger.error(f"Prediction error: {e}")
97
- raise HTTPException(status_code=500, detail="Internal Server Error")
98
 
99
  if __name__ == "__main__":
100
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
 
1
+ import os
2
  import base64
3
  import io
4
  import logging
 
6
  import torch
7
  import librosa
8
  import uvicorn
9
+
10
  from fastapi import FastAPI, HTTPException, Security, Depends
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.security.api_key import APIKeyHeader
13
  from pydantic import BaseModel
14
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
15
 
16
# ======================================================
# CONFIG & SECRETS
# ======================================================
# Header name clients must send the API key in.
API_KEY_NAME = "access_token"
# NOTE(review): hardcoded secret checked into source — prefer loading this
# from an environment variable / Space secret; rotate if it has been exposed.
API_KEY_VALUE = "HCL_SECURE_KEY_2026"

# Get your Hugging Face token from the Space's Secret settings
HF_TOKEN = os.getenv("HF_Token")

# Audio-classification checkpoint used for fake-speech detection.
MODEL_ID = "melba-t/wav2vec2-fake-speech-detection"
# Sample rate (Hz) the model expects; preprocessing resamples to this.
TARGET_SR = 16000
# Model class index -> label string returned by the API.
LABEL_MAP = {0: "HUMAN", 1: "AI_GENERATED"}

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hcl-api")

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
33
 
34
# ======================================================
# MODEL INITIALIZATION (WITH AUTH)
# ======================================================
# Load the feature extractor and classifier once at process start-up.
# On any failure BOTH globals are set to None so request handlers can
# detect the condition and answer 503 instead of crashing at import time.
try:
    logger.info(f"Loading private model {MODEL_ID}...")
    # Passing the token allows access to the private/restricted repo
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN
    )
    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN
    ).to(DEVICE)
    model.eval()  # inference mode: disables dropout etc.
    logger.info("Model loaded successfully.")
except Exception as e:
    # logger.exception records the full traceback, unlike logger.error.
    logger.exception(f"Error loading model: {e}")
    # Fallback to prevent app crash if token is missing. Clear BOTH names:
    # the original fallback set only `model`, leaving `feature_extractor`
    # possibly unbound when the extractor load itself failed.
    feature_extractor = None
    model = None
54
+
55
# ======================================================
# FASTAPI APP
# ======================================================
app = FastAPI(title="HCL AI Voice Detection API")
# Header-based key scheme; auto_error=False so a missing header reaches
# verify_api_key (which returns 403) instead of FastAPI's default error.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
60
 
61
  app.add_middleware(
62
  CORSMiddleware,
 
68
class AudioRequest(BaseModel):
    """Request body for /predict: one base64-encoded audio clip."""

    # Base64 payload; a "data:...;base64," data-URI prefix is tolerated
    # (stripped during preprocessing before decoding).
    audio_base64: str
70
 
 
 
71
async def verify_api_key(api_key: str = Security(api_key_header)):
    """FastAPI dependency: reject requests whose access_token header is missing or wrong.

    Raises:
        HTTPException: 403 when the supplied key does not equal API_KEY_VALUE.
    """
    if api_key != API_KEY_VALUE:
        raise HTTPException(status_code=403, detail="Invalid API Key")
 
78
  if "," in b64_string:
79
  b64_string = b64_string.split(",")[1]
80
 
81
+ # Standardize padding
82
  missing_padding = len(b64_string) % 4
83
  if missing_padding:
84
  b64_string += "=" * (4 - missing_padding)
85
 
86
  audio_bytes = base64.b64decode(b64_string)
87
 
88
+ # Load audio using librosa (requires ffmpeg in packages.txt)
89
  with io.BytesIO(audio_bytes) as bio:
90
  audio, sr = librosa.load(bio, sr=TARGET_SR)
91
 
 
99
 
100
@app.get("/")
def home():
    """Landing endpoint: confirms the service is up and points callers at /docs."""
    status_message = "HCL Voice Detection API Active. Visit /docs"
    return {"message": status_message}
103
 
104
  @app.post("/predict")
105
  async def predict(request: AudioRequest, _: str = Depends(verify_api_key)):
106
+ if model is None:
107
+ raise HTTPException(status_code=503, detail="Model not loaded. Check HF_Token.")
108
+
109
  try:
110
+ waveform = preprocess_audio(request.audio_base64)
111
+
112
+ inputs = feature_extractor(
113
+ waveform,
114
+ sampling_rate=TARGET_SR,
115
+ return_tensors="pt"
116
+ ).to(DEVICE)
117
 
118
  with torch.inference_mode():
119
  logits = model(**inputs).logits
 
121
 
122
  confidence, pred_idx = torch.max(probs, dim=-1)
123
 
124
+ # Map prediction to required hackathon labels
125
+ label = LABEL_MAP.get(int(pred_idx.item()), "UNKNOWN")
126
+
127
  return {
128
+ "classification": label,
129
  "confidence_score": round(float(confidence.item()), 4)
130
  }
131
  except ValueError as ve:
132
  raise HTTPException(status_code=400, detail=str(ve))
133
  except Exception as e:
134
  logger.error(f"Prediction error: {e}")
135
+ raise HTTPException(status_code=500, detail="Inference Error")
136
 
137
if __name__ == "__main__":
    # Port 7860 is required for Hugging Face Spaces.
    # Bind to all interfaces so the Space's reverse proxy can reach the server.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)