File size: 3,900 Bytes
151ed35
 
 
 
 
 
 
 
4e86f46
151ed35
7fb32a3
151ed35
4e86f46
 
151ed35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af0b8e
 
 
 
151ed35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Standard library
import os
import tempfile
from pathlib import Path

# Third-party
import torch
import torch.nn.functional as F
import torchaudio
import uvicorn
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForAudioClassification

# Directory containing this file; used to resolve the bundled model path.
app_dir = Path(__file__).parent
# Model setup: weights are expected on disk under Deepfake/model next to
# this script (no hub download — local_files_only below enforces this).
model_path = app_dir / "Deepfake" / "model"

# Feature extractor / tokenizer matching the classification model.
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForAudioClassification.from_pretrained(
    pretrained_model_name_or_path=model_path,
    local_files_only=True,
)

def prepare_audio(file_path, sampling_rate=16000, duration=10):
    """
    Load an audio file, downmix to mono, resample, and split into chunks.

    Args:
        file_path: Path to an audio file readable by torchaudio.
        sampling_rate: Target sample rate in Hz (default 16 kHz).
        duration: Chunk length in seconds (default 10 s).

    Returns:
        List of 1-D numpy arrays, each exactly ``sampling_rate * duration``
        samples long; the trailing chunk is zero-padded. Always contains at
        least one chunk, even for a zero-length input.
    """
    waveform, original_sampling_rate = torchaudio.load(file_path)

    # Downmix multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample only when the file's native rate differs from the target.
    if original_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sampling_rate, new_freq=sampling_rate
        )
        waveform = resampler(waveform)

    chunk_size = sampling_rate * duration
    audio_chunks = []

    # max(..., 1) guarantees at least one (fully zero-padded) chunk for a
    # zero-length file, so downstream aggregation never divides by zero.
    for start in range(0, max(waveform.shape[1], 1), chunk_size):
        chunk = waveform[:, start:start + chunk_size]

        # Zero-pad a short trailing chunk to the full chunk length.
        # Use the already-imported F alias rather than the fully-qualified
        # torch.nn.functional path, for consistency with the rest of the file.
        if chunk.shape[1] < chunk_size:
            chunk = F.pad(chunk, (0, chunk_size - chunk.shape[1]))

        audio_chunks.append(chunk.squeeze().numpy())

    return audio_chunks

def predict_audio(file_path):
    """
    Classify an audio file by running the model over fixed-length chunks
    and aggregating the per-chunk predictions.

    Args:
        file_path: Path to the audio file to classify.

    Returns:
        Dict with ``"predicted_label"`` (majority-vote class name) and
        ``"average_confidence"`` (mean softmax confidence across chunks).

    Raises:
        ValueError: If no audio chunks could be produced from the file.
    """
    audio_chunks = prepare_audio(file_path)
    # Explicit guard: without it an empty chunk list would surface as an
    # opaque ZeroDivisionError in the average below.
    if not audio_chunks:
        raise ValueError(f"No audio data could be extracted from {file_path}")

    predictions = []
    confidences = []

    for chunk in audio_chunks:
        inputs = processor(
            chunk, sampling_rate=16000, return_tensors="pt", padding=True
        )

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            logits = model(**inputs).logits
            # Softmax turns logits into per-class probabilities.
            probabilities = F.softmax(logits, dim=1)
            confidence, predicted_class = torch.max(probabilities, dim=1)

        predictions.append(predicted_class.item())
        confidences.append(confidence.item())

    # Majority vote across chunks; ties resolve arbitrarily via max().
    aggregated_prediction_id = max(set(predictions), key=predictions.count)
    predicted_label = model.config.id2label[aggregated_prediction_id]

    average_confidence = sum(confidences) / len(confidences)

    return {
        "predicted_label": predicted_label,
        "average_confidence": average_confidence,
    }

# Initialize FastAPI application; route handlers below attach to this object.
app = FastAPI()

@app.post("/infer")
async def infer(file: UploadFile = File(...)):
    """
    Accept an uploaded audio file and return the predicted label and
    average confidence.

    The upload is spooled to a real temporary file because torchaudio
    loads from a filesystem path; the file is always removed afterwards.
    """
    # SECURITY: the original f"temp_{file.filename}" pattern trusted the
    # client-supplied filename — a name like "../../x.wav" escapes the
    # working directory, and concurrent requests with the same name
    # collide. mkstemp creates a unique, safe path instead. The original
    # suffix is kept so torchaudio can infer the container format.
    suffix = Path(file.filename or "upload").suffix
    fd, temp_file_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(await file.read())
        # Perform inference on the spooled file.
        predictions = predict_audio(temp_file_path)
    finally:
        # Clean up the temporary file even if inference raises.
        os.remove(temp_file_path)

    return predictions

@app.get("/health")
async def health():
    """Liveness probe: reports status and the torchaudio backends available."""
    backends = torchaudio.list_audio_backends()
    return {
        "message": "ok",
        "Sound": str(backends),
    }