ROSHANNN123 committed on
Commit
b08f86a
·
verified ·
1 Parent(s): f3fce37

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +26 -0
  2. main.py +87 -0
  3. model_service.py +92 -0
  4. requirements.txt +10 -0
  5. schemas.py +9 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image.
# NOTE: python:3.10+ is required — schemas.py uses a PEP 604 union
# ("str | None"), which raises TypeError at import time on Python 3.9.
FROM python:3.11

# Set working directory for the application code
WORKDIR /code

# Copy requirements first so dependency installation is cached by Docker layers
COPY ./requirements.txt /code/requirements.txt

# Install dependencies; --no-cache-dir keeps the image smaller
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the rest of the application
COPY . /code

# Create a world-writable cache directory: Hugging Face Spaces run the
# container as a non-root user (UID 1000), which cannot write to /root.
RUN mkdir -p /code/cache && chmod -R 777 /code/cache
ENV XDG_CACHE_HOME=/code/cache
# Point Hugging Face model downloads (transformers) at the writable cache too,
# otherwise from_pretrained() tries to write under the user's home directory.
ENV HF_HOME=/code/cache

# Hugging Face Spaces expect the web app on port 7860
EXPOSE 7860

# Start the uvicorn server on port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import binascii
3
+ from fastapi import FastAPI, HTTPException, Depends, Header
4
+ from schemas import AudioInput, DetectionResult
5
+ from model_service import get_model_service, ModelService
6
+
7
# FastAPI application instance; the metadata below appears in the
# auto-generated OpenAPI docs at /docs.
app = FastAPI(
    title="AI Voice Detection API",
    description="Detects whether a voice sample is AI-generated or Human-spoken.",
    version="1.0.0"
)
12
+
13
@app.on_event("startup")
async def startup_event():
    """Warm-up hook: build the model singleton before the first request,
    so clients never pay the model-loading latency."""
    _ = get_model_service()
17
+
18
import os

# API key expected in the X-API-Key header. Read from the environment when
# available so the secret is not baked into the source/image; falls back to
# the original hardcoded value for backward compatibility.
API_KEY = os.getenv("API_KEY", "my_secret_key_123")

async def verify_api_key(x_api_key: str = Header(...)):
    """FastAPI dependency: reject any request whose X-API-Key header does not
    match API_KEY.

    Raises:
        HTTPException: 401 when the key mismatches.
    Returns:
        The validated key (so dependents can reuse it if needed).
    """
    if x_api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return x_api_key
24
+
25
# Only Request is new here; FastAPI/HTTPException/Depends/Header are already
# imported at the top of the file — the previous duplicate re-import was redundant.
from fastapi import Request
26
+
27
+ # ... (Previous imports stay, schema can stay unused or updated)
28
+
29
@app.post("/detect", response_model=DetectionResult)
async def detect_voice(
    request: Request,
    service: ModelService = Depends(get_model_service),
    api_key: str = Depends(verify_api_key)
):
    """Detect whether a base64-encoded audio sample is AI-generated or human.

    The body is parsed manually (not via a Pydantic model) so the endpoint
    tolerates different client key names for the audio payload.

    Raises:
        HTTPException: 422 when no audio field is found, 400 on malformed
            JSON/base64, 500 on unexpected model errors.
    """
    try:
        # 1. Parse the JSON body manually to stay flexible about key names.
        body = await request.json()
        print(f"DEBUG: Received Body Keys: {list(body.keys())}")

        # 2. Look for the base64 string under common key names first.
        audio_b64 = None
        possible_keys = ["audio_base64", "audio", "data", "file", "encoded_audio", "mp3"]
        for k in possible_keys:
            if k in body and body[k]:
                audio_b64 = body[k]
                print(f"DEBUG: Found audio in key: '{k}'")
                break

        # Fallback: any long string value is assumed to be the audio payload.
        if not audio_b64:
            for k, v in body.items():
                if isinstance(v, str) and len(v) > 100:
                    audio_b64 = v
                    print(f"DEBUG: Found audio in generic key: '{k}'")
                    break

        if not audio_b64:
            raise HTTPException(status_code=422, detail=f"Could not find audio data. Received keys: {list(body.keys())}")

        # Strip a data-URI prefix if present (e.g. "data:audio/mp3;base64,....").
        # split(",", 1) keeps everything after the first comma; base64 itself
        # never contains commas.
        if "," in audio_b64:
            audio_b64 = audio_b64.split(",", 1)[1]

        audio_bytes = base64.b64decode(audio_b64)
    except HTTPException:
        # BUG FIX: previously the 422 raised above was swallowed by the
        # generic handler below and re-surfaced as a 400 with a mangled
        # detail string. Let deliberate HTTP errors propagate unchanged.
        raise
    except Exception as e:
        print(f"Error parsing request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}")

    try:
        label, confidence = service.predict(audio_bytes)
        return DetectionResult(
            label=label,
            confidence=confidence,
            message="Analysis successful"
        )
    except ValueError as ve:
        # predict() raises ValueError for undecodable/corrupt audio.
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
84
+
85
@app.get("/")
def read_root():
    """Landing/health-check endpoint with a pointer to the real API."""
    payload = {"message": "AI Voice Detection API is running. Use /detect endpoint."}
    return payload
model_service.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import numpy as np
4
+ import io
5
+ import soundfile as sf
6
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
7
+ import torch.nn.functional as F
8
+
9
# Configuration
# Hugging Face Hub model id for the deepfake-audio classifier (wav2vec2-based,
# loaded in ModelService below).
MODEL_NAME = "Hemgg/Deepfake-audio-detection" # Using a known fine-tuned model
# Alternative: "mo-thecreator/Deepfake-audio-detection" if the above fails or is private
# But usually public models are fine.
13
+
14
class ModelService:
    """Wraps a Hugging Face audio-classification model and exposes a single
    predict() call that maps raw audio bytes to ("HUMAN"/"AI_GENERATED", confidence).
    """

    def __init__(self):
        """Load the feature extractor and model onto GPU when available."""
        print(f"Loading model: {MODEL_NAME}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
            self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
            # FIX: switch to inference mode — without eval(), dropout stays
            # active and predictions are nondeterministic.
            self.model.eval()
            print(f"Model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # FIX: bare raise preserves the original traceback ('raise e' resets it).
            raise

    def preprocess_audio(self, audio_bytes):
        """Decode audio bytes and resample to mono 16 kHz (required by Wav2Vec2).

        Raises:
            ValueError: when the bytes are not decodable audio.
        """
        try:
            # librosa.load accepts file-like objects; mono=True is its default,
            # so multi-channel input is averaged to a single channel.
            audio_file = io.BytesIO(audio_bytes)
            speech, _sr = librosa.load(audio_file, sr=16000)
            return speech
        except Exception as e:
            print(f"Error processing audio: {e}")
            raise ValueError("Invalid audio format or corrupted file.") from e

    def predict(self, audio_bytes):
        """Classify audio bytes.

        Returns:
            (label, confidence): label is "HUMAN", "AI_GENERATED", or the raw
            model label when it matches neither pattern; confidence is the
            softmax probability of the predicted class.
        """
        speech = self.preprocess_audio(audio_bytes)

        # Extract model input features at the model's expected sampling rate.
        inputs = self.feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probs = F.softmax(logits, dim=-1)

        predicted_id = torch.argmax(probs, dim=-1).item()
        # FIX: config.id2label keys are usually ints but become strings after a
        # JSON round-trip on some checkpoints; accept either to avoid KeyError.
        id2label = self.model.config.id2label
        predicted_label = id2label.get(predicted_id, id2label.get(str(predicted_id), str(predicted_id)))
        confidence = probs[0][predicted_id].item()

        # Map model-specific labels to the required output vocabulary.
        # Common checkpoints use real/fake, human/ai, or bonafide/spoof.
        lower_label = str(predicted_label).lower()
        if "real" in lower_label or "human" in lower_label or "bonafide" in lower_label:
            normalized_label = "HUMAN"
        elif "fake" in lower_label or "spoof" in lower_label or "ai" in lower_label:
            normalized_label = "AI_GENERATED"
        else:
            # Obscure label set: pass the raw label through rather than guess
            # from index position (0=real/1=fake is common but not universal).
            normalized_label = predicted_label

        return normalized_label, confidence
84
+
85
# Lazily-created singleton shared by every request handler.
model_service = None

def get_model_service():
    """Return the process-wide ModelService, constructing it on first use."""
    global model_service
    service = model_service
    if service is None:
        service = ModelService()
        model_service = service
    return service
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ librosa
6
+ soundfile
7
+ python-multipart
8
+ numpy
9
+ requests
10
+ gTTS
schemas.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from typing import Optional

from pydantic import BaseModel
2
+
3
class AudioInput(BaseModel):
    """Request schema for /detect: a base64-encoded audio payload."""
    audio_base64: str  # base64 (optionally data-URI prefixed) audio bytes
5
+
6
class DetectionResult(BaseModel):
    """Response schema for /detect."""
    label: str # "AI_GENERATED" or "HUMAN"
    confidence: float  # softmax probability of the predicted class (0..1)
    # FIX: Optional[str] instead of "str | None" — the Dockerfile pins
    # python:3.9, where PEP 604 unions raise TypeError when pydantic
    # evaluates the annotation at runtime.
    message: Optional[str] = None