File size: 7,839 Bytes
63e1917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import os
import contextlib
import wave
import librosa
import numpy as np
import pandas as pd
import parselmouth
import soundfile as sf
import webrtcvad
from tensorflow.keras.models import load_model
import joblib
import warnings
import tempfile

# --- FastAPI Imports ---
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse

# --- Suppress Warnings ---
# Silence all Python warnings and TensorFlow's C++ log messages.
# NOTE(review): tensorflow is imported above, before this variable is set.
# TensorFlow may read TF_CPP_MIN_LOG_LEVEL when it is first imported, so
# messages emitted during that import may not be suppressed. Consider setting
# it before the tensorflow import — confirm.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# --- Configuration ---
# Sample rate (Hz) that every upload is resampled to before feature extraction.
TARGET_SR = 16000
# Model artifacts; relative paths, so they resolve against the current
# working directory of the server process.
MODEL_PATH = "vocal_model.h5"
SCALER_PATH = "vocal_scaler.joblib"
FEATURES_PATH = "feature_names.joblib"

# --- Load Models and Scaler at Startup ---
# Done once at import time so every request reuses the same objects.
# All-or-nothing: if any artifact fails to load, all three names are reset to
# None, and predict_from_audio_path then answers with HTTP 503.
# NOTE: joblib.load unpickles its input; only point these paths at trusted files.
try:
    model = load_model(MODEL_PATH)
    scaler = joblib.load(SCALER_PATH)
    feature_names = joblib.load(FEATURES_PATH)
    print("✅ Model, scaler, and feature list loaded successfully.")
except Exception as load_error:
    print("❌ FATAL ERROR: Could not load model files. The application will not work.")
    print(f"   Details: {load_error}")
    # The app still starts in this degraded state; a production deployment may
    # prefer to re-raise here so a missing artifact fails fast.
    model = scaler = feature_names = None

# --- Feature Extraction Functions (copied from the original script) ---
# preprocess_audio and extract_features below should stay in sync with the
# script they were copied from: the model and scaler presumably expect
# features computed exactly the same way.
def preprocess_audio(input_path, target_sr=TARGET_SR):
    """
    Write a mono, ``target_sr`` Hz, 16-bit PCM WAV copy of an audio file.

    The copy is written next to the input as
    ``<input path without extension>_processed_for_prediction.wav``; the
    caller is responsible for deleting it.

    Args:
        input_path: Path to an audio file that librosa can read.
        target_sr: Sample rate of the output file, in Hz.

    Returns:
        The path of the written WAV file, or None if loading, resampling or
        writing failed. Errors are printed, not raised. A partially written
        output file is removed on failure.
    """
    output_path = None
    try:
        # Load at the native rate with all channels so downmix/resample are explicit.
        data, sr = librosa.load(input_path, sr=None, mono=False)
        if data.ndim > 1:
            # Multichannel audio: average across the channel axis to get mono.
            data = data.mean(axis=0)
        if sr != target_sr:
            data = librosa.resample(data, orig_sr=sr, target_sr=target_sr)
        base, _ext = os.path.splitext(input_path)
        output_path = f"{base}_processed_for_prediction.wav"
        sf.write(output_path, data, target_sr, subtype='PCM_16')
        return output_path
    except Exception as e:
        print(f"Error preprocessing {input_path}: {e}")
        # The caller gets None and cannot clean up, so don't leave a
        # half-written output file behind.
        if output_path is not None:
            with contextlib.suppress(OSError):
                os.remove(output_path)
        return None

def extract_features(file_path: str) -> "dict | None":
    """
    Compute the feature dictionary fed to the scaler/model for one WAV file.

    Features (dictionary keys):
        Duration:        length of the audio in seconds (librosa).
        Pitch_Mean:      mean of the Praat pitch track, ignoring frames whose
                         frequency is 0 (presumably unvoiced); 0 if none remain.
        Pitch_Std:       standard deviation of the same pitch values; 0 if none.
        Jitter:          Praat "Get jitter (local)".
        Shimmer:         Praat "Get shimmer (local)".
        Speaking_Rate:   number of 30 ms frames webrtcvad flags as speech,
                         divided by Duration (voiced frames per second, not
                         segments per second, despite the variable name).
        Silence_Ratio:   share of Duration not covered by those speech frames,
                         floored at 0.
        MFCC_1..MFCC_13: time-average of each of 13 MFCC coefficients.

    NOTE(review): the model and scaler presumably expect features computed
    exactly as in the script these functions were copied from. Do not change
    any computation here (including the quirks noted inline) without
    retraining, or predictions will be made on shifted features.
    NOTE(review): Praat may return undefined (NaN) jitter/shimmer when too few
    periods are detected — confirm; such values propagate to the caller.

    Args:
        file_path: Path to a WAV file. It is also read with the ``wave``
            module for voice-activity detection, so it must be WAV; the VAD
            frame size below assumes 16-bit mono PCM, which is what
            preprocess_audio writes.

    Returns:
        The feature dict, or None if anything raises (the error is printed).
    """
    try:
        y, sr = librosa.load(file_path, sr=None)
        duration = librosa.get_duration(y=y, sr=sr)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        # Average each of the 13 coefficients over time.
        mfcc_means = np.mean(mfccs, axis=1)

        snd = parselmouth.Sound(file_path)
        pitch = snd.to_pitch()
        pitch_values = pitch.selected_array['frequency']
        # Drop frames reported with frequency 0 (presumably unvoiced) before the stats.
        pitch_values = pitch_values[pitch_values != 0]

        pitch_mean = np.mean(pitch_values) if len(pitch_values) > 0 else 0
        pitch_std = np.std(pitch_values) if len(pitch_values) > 0 else 0

        # 75 and 500 are Praat's pitch floor and ceiling (Hz) for period detection.
        point_process = parselmouth.praat.call(snd, "To PointProcess (periodic, cc)", 75, 500)
        # Praat arguments: time range (0, 0 = whole sound), shortest period,
        # longest period, maximum period factor (+ maximum amplitude factor for shimmer).
        jitter_local = parselmouth.praat.call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
        shimmer_local = parselmouth.praat.call([snd, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)

        def read_wave(path):
            # Return the raw PCM bytes and the sample rate of a WAV file.
            with contextlib.closing(wave.open(path, 'rb')) as wf:
                pcm_data, sample_rate = wf.readframes(wf.getnframes()), wf.getframerate()
                return pcm_data, sample_rate
        
        def frame_generator(frame_duration_ms, audio, sample_rate):
            # Yield consecutive, non-overlapping frame_duration_ms chunks of the
            # raw PCM byte string. "* 2" converts samples to bytes, which
            # assumes 16-bit samples in a single channel.
            # NOTE(review): with "<", a last frame that ends exactly at the end
            # of the audio is not yielded (a trailing partial frame is always
            # dropped). See the docstring before changing this.
            n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
            offset = 0
            while offset + n < len(audio):
                yield audio[offset:offset + n]
                offset += n
        
        # webrtcvad aggressiveness mode 1 (0 = least, 3 = most aggressive).
        vad = webrtcvad.Vad(1)
        audio, sample_rate = read_wave(file_path)
        frames = list(frame_generator(30, audio, sample_rate))
        voiced_seconds = 0
        # Counts speech *frames*, not contiguous speech segments.
        num_segments = 0
        if frames:
            for frame in frames:
                if vad.is_speech(frame, sample_rate):
                    voiced_seconds += 0.03 # 30ms frame; must match the 30 passed to frame_generator
                    num_segments +=1

        silence_ratio = max(0, (duration - voiced_seconds) / duration) if duration > 0 else 0
        speaking_rate = num_segments / duration if duration > 0 else 0

        features = {
            'Duration': duration,
            'Pitch_Mean': pitch_mean,
            'Pitch_Std': pitch_std,
            'Jitter': jitter_local,
            'Shimmer': shimmer_local,
            'Speaking_Rate': speaking_rate,
            'Silence_Ratio': silence_ratio,
        }
        for idx, val in enumerate(mfcc_means):
            features[f'MFCC_{idx+1}'] = val
            
        return features

    except Exception as e:
        # Best-effort: any failure is reported and signalled to the caller as None.
        print(f"Error extracting features from {file_path}: {e}")
        return None

# --- Main Prediction Logic (Refactored to return a dictionary) ---

def predict_from_audio_path(file_path):
    """
    Run the full prediction pipeline on one audio file and return a result dict.

    Pipeline: preprocess_audio (mono, TARGET_SR, 16-bit WAV copy) ->
    extract_features -> columns reordered to feature_names ->
    scaler.transform -> model.predict. The intermediate WAV copy is always
    deleted before this function returns or raises.

    Args:
        file_path: Path to an audio file that librosa can read.

    Returns:
        dict with keys:
            "status": always "success".
            "prediction": "Parkinson's Detected" or "Healthy".
            "confidence": the raw model output that is thresholded at 0.5.
                It is the score for "Parkinson's Detected" and is reported
                unchanged even when the prediction is "Healthy".
            "label": 1 for "Parkinson's Detected", 0 for "Healthy".

    Raises:
        HTTPException: 503 if the model artifacts are not loaded; 400 if
            preprocessing or feature extraction fails, or the extracted
            features contain NaN/infinite values; 500 for any other error
            while building, scaling or scoring the features.
    """
    # Compare with None explicitly: feature_names is whatever joblib stored and
    # may be a NumPy array or pandas Index, whose truth value raises ValueError.
    if model is None or scaler is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or available.")

    # 1. Preprocess audio (writes a temporary resampled copy next to the input).
    processed_path = preprocess_audio(file_path)
    if not processed_path:
        raise HTTPException(status_code=400, detail="Audio preprocessing failed.")

    try:
        # 2. Extract features.
        features_dict = extract_features(processed_path)
        if not features_dict:
            raise HTTPException(status_code=400, detail="Feature extraction failed.")

        # 3. One-row DataFrame with columns in the order the scaler/model expect.
        feature_df = pd.DataFrame([features_dict])[feature_names]

        # Non-finite inputs (e.g. undefined Praat jitter/shimmer) would yield a
        # NaN probability, which JSONResponse refuses to serialize. Reject them
        # as a bad recording instead of failing later with an opaque error.
        if not np.isfinite(feature_df.to_numpy(dtype=float)).all():
            raise HTTPException(
                status_code=400,
                detail="Feature extraction produced invalid (NaN or infinite) values; "
                "the recording may not contain enough voiced speech.",
            )

        # 4. Scale features.
        scaled_features = scaler.transform(feature_df)

        # 5. Predict: first (only) sample, first output unit.
        prediction_prob = float(model.predict(scaled_features, verbose=0)[0][0])
        prediction_label = int(prediction_prob > 0.5)

        # 6. Format the result.
        result_text = "Parkinson's Detected" if prediction_label == 1 else "Healthy"
        return {
            "status": "success",
            "prediction": result_text,
            "confidence": prediction_prob,
            "label": prediction_label,
        }
    except HTTPException:
        # Already carries the intended status code; don't rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during prediction: {str(e)}") from e
    finally:
        # 7. Always delete the temporary processed file. A failed delete is only
        # reported so it cannot replace the result or the original error.
        try:
            os.remove(processed_path)
        except OSError as cleanup_error:
            print(f"Warning: could not remove temporary file {processed_path}: {cleanup_error}")


# --- FastAPI App Definition ---

app = FastAPI(
    title="Parkinson's Voice Detection API",
    version="1.0",
    description=(
        "An API that uses a deep learning model to predict the presence of "
        "Parkinson's disease from a voice recording."
    ),
)

@app.get("/", tags=["General"])
def read_root():
    """Liveness check: respond with a greeting that points to the /docs page."""
    greeting = "Welcome to the Parkinson's Voice Prediction API. Go to /docs for usage."
    return {"message": greeting}

@app.post("/predict/", tags=["Prediction"])
async def create_prediction(file: UploadFile = File(...)):
    """
    Accepts an audio file, processes it, and returns the prediction result.
    The audio file can be in any format that librosa supports (wav, mp3, etc.).
    """
    # NOTE(review): predict_from_audio_path is synchronous (audio decoding,
    # feature extraction, model inference) and is called directly in this
    # async endpoint, so it blocks the event loop, and every other request,
    # while it runs. Consider offloading it to a worker thread once the
    # thread-safety of the loaded model has been confirmed.

    # Keep the upload's extension on the temporary copy. UploadFile.filename
    # may be None, which os.path.splitext would reject.
    suffix = os.path.splitext(file.filename or "")[1]
    tmp_file_path = None
    try:
        # Save the uploaded file to a temporary location on the server.
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                # Record the path first so the file is removed even if the write fails.
                tmp_file_path = tmp_file.name
                content = await file.read()
                tmp_file.write(content)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error handling the uploaded file: {e}") from e

        # Run the prediction on the saved temporary file.
        result = predict_from_audio_path(tmp_file_path)
        return JSONResponse(content=result)
    finally:
        # Always clean up the temporary file. A failed delete is only reported
        # so it cannot replace the response or the original error.
        if tmp_file_path is not None:
            try:
                os.remove(tmp_file_path)
            except OSError as cleanup_error:
                print(f"Warning: could not remove temporary file {tmp_file_path}: {cleanup_error}")