RJ40under40 committed on
Commit
f3ff9bf
·
verified ·
1 Parent(s): 258be08

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ======================================================
2
+ # HCL AI VOICE DETECTION API
3
+ # Hugging Face Spaces (FastAPI)
4
+ # ======================================================
5
+
6
import asyncio
import base64
import io
import logging
import os
import secrets

import librosa
import torch
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
18
+
19
+ # ======================================================
20
+ # CONFIGURATION
21
+ # ======================================================
22
# Name of the HTTP header clients must send their API key in.
API_KEY_NAME = "access_token"

# Expected API key. Prefer supplying it through the HCL_API_KEY environment
# variable (for example as a deployment secret) so the real key never lives in
# the repository. The literal below is kept only as a backward-compatible
# fallback for deployments that have not set the variable yet; an empty
# variable is treated as unset so the key can never become the empty string.
API_KEY_VALUE = os.environ.get("HCL_API_KEY") or "HCL_SECURE_KEY_2026"

# Hugging Face model identifier used for both the feature extractor and model.
MODEL_ID = "facebook/wav2vec2-base-960h"
# Sampling rate (Hz) audio is resampled to before feature extraction.
TARGET_SR = 16000
27
+
28
+ # ======================================================
29
+ # LOGGING
30
+ # ======================================================
31
# Configure the root logger to emit INFO and above, then create the
# service-specific logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("voice-detection")
33
+
34
+ # ======================================================
35
+ # DEVICE & MODEL LOADING (RUNS ON STARTUP)
36
+ # ======================================================
37
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# Both objects are created once at import time and shared by every request.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)

# NOTE(review): MODEL_ID appears to be a speech-recognition checkpoint; loading
# it as an audio classifier with num_labels=2 presumably leaves the
# classification head newly initialised, which would make predictions
# arbitrary. Confirm that a checkpoint fine-tuned for this task is intended.
model = AutoModelForAudioClassification.from_pretrained(MODEL_ID, num_labels=2)
model = model.to(DEVICE)

# Switch to evaluation mode for inference.
model.eval()
logger.info("Model loaded successfully")
48
+
49
+ # ======================================================
50
+ # FASTAPI APP
51
+ # ======================================================
52
# Application object that the route decorators below register against.
app = FastAPI(title="HCL AI Voice Detection API", version="1.0.0")

# Allow cross-origin requests from any origin, with any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Reads the API key from the API_KEY_NAME header. auto_error=False means a
# missing header is passed to the dependency as None instead of being rejected
# by FastAPI itself.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
65
+
66
+ # ======================================================
67
+ # SCHEMAS
68
+ # ======================================================
69
# Request body for /predict.
class AudioRequest(BaseModel):
    # Base64-encoded audio; decode_audio() also accepts a "prefix,payload"
    # form and uses only the part after the last comma.
    audio_base64: str
71
+
72
+
73
# Response body for /predict.
class PredictionResponse(BaseModel):
    # "AI_GENERATED" or "HUMAN", as produced by analyze_voice().
    classification: str
    # Softmax probability of the predicted class, rounded to 4 decimals.
    confidence_score: float
76
+
77
+
78
+ # ======================================================
79
+ # SECURITY
80
+ # ======================================================
81
async def verify_api_key(api_key: str = Security(api_key_header)):
    """Validate the API key sent in the request header.

    Args:
        api_key: Value of the ``API_KEY_NAME`` header, or ``None`` when the
            header is absent (the header scheme uses ``auto_error=False``).

    Returns:
        The supplied key when it matches ``API_KEY_VALUE``.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    # Compare as bytes: compare_digest() rejects non-ASCII str arguments, and
    # header values are attacker-controlled. The constant-time comparison
    # avoids leaking how many leading characters of the key were correct.
    key_matches = api_key is not None and secrets.compare_digest(
        api_key.encode("utf-8"),
        API_KEY_VALUE.encode("utf-8"),
    )
    if not key_matches:
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return api_key
85
+
86
+
87
+ # ======================================================
88
+ # CORE LOGIC
89
+ # ======================================================
90
def decode_audio(b64_audio: str) -> bytes:
    """Decode a Base64 audio payload into raw bytes.

    Args:
        b64_audio: Base64 text, optionally preceded by a comma-terminated
            prefix such as ``data:audio/wav;base64,``. Only the part after the
            last comma is decoded. Whitespace inside the payload is ignored.

    Returns:
        The decoded, non-empty audio bytes.

    Raises:
        HTTPException: 400 when the payload is not valid Base64 or decodes to
            zero bytes.
    """
    payload = b64_audio.split(",")[-1]
    # Strip whitespace first so line-wrapped Base64 still passes strict
    # validation below.
    payload = "".join(payload.split())

    try:
        # validate=True makes characters outside the Base64 alphabet an error;
        # the default silently discards them and decodes whatever is left.
        audio_bytes = base64.b64decode(payload, validate=True)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail="Invalid Base64 audio"
        ) from exc

    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Audio payload is empty")
    return audio_bytes
95
+
96
+
97
def analyze_voice(audio_bytes: bytes) -> tuple[str, float]:
    """Classify an encoded audio clip as human or AI-generated speech.

    The clip is decoded with ``librosa.load`` (resampled to ``TARGET_SR``,
    mixed down to mono), passed through the module-level feature extractor and
    model, and the highest-probability class is reported.

    Args:
        audio_bytes: Raw bytes of an encoded audio file.

    Returns:
        ``(label, confidence)`` where ``label`` is ``"AI_GENERATED"`` or
        ``"HUMAN"`` and ``confidence`` is the softmax probability of the
        predicted class rounded to four decimals.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as audio or decode
            to zero samples.
    """
    try:
        audio, _ = librosa.load(
            io.BytesIO(audio_bytes),
            sr=TARGET_SR,
            mono=True,
        )
    except Exception as exc:
        # librosa delegates to several decoding backends that raise unrelated
        # exception types, so this boundary catches broadly and reports a
        # client error instead of an unhandled 500.
        logger.warning("Could not decode audio payload: %s", exc)
        raise HTTPException(
            status_code=400, detail="Unsupported or corrupt audio data"
        ) from exc

    if audio.size == 0:
        raise HTTPException(status_code=400, detail="Audio contains no samples")

    inputs = feature_extractor(
        audio,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
    )
    # Move every input tensor to the device the model lives on.
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.inference_mode():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

    confidence, prediction = torch.max(probs, dim=-1)
    # NOTE(review): class index 1 is treated as AI-generated; confirm this
    # matches the label order the model was trained with.
    label = "AI_GENERATED" if prediction.item() == 1 else "HUMAN"

    return label, round(confidence.item(), 4)
120
+
121
+
122
+ # ======================================================
123
+ # ENDPOINTS
124
+ # ======================================================
125
@app.get("/health")
def health():
    """Report that the service is running and which device the model uses."""
    status_report = dict(status="ok", device=DEVICE)
    return status_report
128
+
129
+
130
@app.post(
    "/predict",
    response_model=PredictionResponse
)
async def predict(
    request: AudioRequest,
    _: str = Depends(verify_api_key)
):
    """Classify the Base64-encoded audio in the request body.

    Args:
        request: Body carrying the Base64 audio payload.
        _: Result of the API-key dependency; unused, present only so the key
            check runs before the handler.

    Returns:
        A mapping with ``classification`` and ``confidence_score`` matching
        ``PredictionResponse``.

    Raises:
        HTTPException: Propagated from ``decode_audio`` / ``analyze_voice``
            for invalid input, or from the API-key dependency.
    """
    audio_bytes = decode_audio(request.audio_base64)

    # Audio decoding and model inference are blocking, compute-heavy work.
    # Running them in a worker thread keeps the event loop free to serve other
    # requests (including /health) while a prediction is in progress.
    label, score = await asyncio.to_thread(analyze_voice, audio_bytes)

    return {
        "classification": label,
        "confidence_score": score
    }