File size: 8,402 Bytes
9cb7f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import os
import contextlib
import wave
import librosa
import numpy as np
import pandas as pd
import parselmouth
import soundfile as sf
import webrtcvad
from tensorflow.keras.models import load_model
import joblib
import warnings
import tempfile
import json

# --- Streamlit Imports ---
import streamlit as st
from sklearn.preprocessing import StandardScaler

# --- Configuration ---
# Sample rate (Hz) every upload is resampled to before analysis.
TARGET_SR = 16_000

# Artifact locations (relative paths, so resolved against the working directory).
MODEL_PATH = "vocal_model.h5"           # Keras model loaded with load_model()
SCALER_PATH_JSON = "vocal_scaler.json"  # StandardScaler parameters (mean_ / scale_) as JSON
FEATURES_PATH = "feature_names.joblib"  # column order the scaler and model expect

# --- Suppress Warnings ---
# NOTE(review): tensorflow.keras is imported above this point, so this variable can
# only affect TensorFlow C++ logging that has not been initialised yet. Confirm it
# still silences the logs you care about, or move it before the TensorFlow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')

# --- Caching the Models and Scaler ---
# This is a key Streamlit feature. It loads the models only once, making the app fast.
@st.cache_resource
def load_models_and_scaler():
    """Load and cache the Keras model, the feature scaler and the feature order.

    ``st.cache_resource`` keeps the returned objects across script reruns, so the
    artifacts are read from disk once instead of on every interaction.

    The scaler is rebuilt from ``SCALER_PATH_JSON``, which must contain the
    ``mean_`` and ``scale_`` arrays of a fitted ``StandardScaler``.

    Returns:
        tuple: ``(model, scaler, feature_names)``.

    Raises:
        OSError: if an artifact file cannot be read.
        KeyError: if the scaler JSON lacks ``mean_`` or ``scale_``.
        ValueError: if the scaler parameters do not have one entry per feature name.
    """
    model = load_model(MODEL_PATH)

    # Load scaler parameters from JSON.
    with open(SCALER_PATH_JSON, 'r', encoding='utf-8') as f:
        scaler_data = json.load(f)

    mean = np.asarray(scaler_data['mean_'], dtype=np.float64)
    scale = np.asarray(scaler_data['scale_'], dtype=np.float64)

    # NOTE(review): joblib.load unpickles the file; only use trusted artifacts here.
    feature_names = joblib.load(FEATURES_PATH)

    # Catch artifact mismatches at load time. Otherwise they surface later inside
    # scaler.transform as an opaque shape error, or as a silent broadcast when an
    # array has length 1.
    n_features = len(feature_names)
    if mean.shape != (n_features,) or scale.shape != (n_features,):
        raise ValueError(
            f"Scaler parameters do not match the feature list: mean_ shape {mean.shape}, "
            f"scale_ shape {scale.shape}, expected ({n_features},)"
        )

    scaler = StandardScaler()
    scaler.mean_ = mean
    scaler.scale_ = scale
    # Lets scikit-learn validate the input width inside transform().
    scaler.n_features_in_ = n_features

    return model, scaler, feature_names

# --- Feature Extraction Functions ---
# Keep these identical to the versions used to build the training data: any
# change shifts the feature values the trained model and scaler expect.
def preprocess_audio(input_path, target_sr=TARGET_SR):
    """Convert an audio file to a mono, ``target_sr`` Hz, 16-bit PCM WAV file.

    The converted copy is written next to ``input_path`` as
    ``<base>_processed_for_prediction.wav``. The caller is responsible for
    deleting it.

    Args:
        input_path: Path to an audio file librosa can decode.
        target_sr: Output sample rate in Hz (defaults to ``TARGET_SR``).

    Returns:
        The path of the written WAV file. Returns ``None`` if loading,
        resampling or writing failed; the error is shown in the Streamlit UI.
    """
    try:
        data, sr = librosa.load(input_path, sr=None, mono=False)
        # Down-mix multi-channel audio (channels on axis 0) to mono.
        if data.ndim > 1:
            data = data.mean(axis=0)
        if sr != target_sr:
            data = librosa.resample(data, orig_sr=sr, target_sr=target_sr)
        # Resampling can overshoot full scale slightly. libsndfile does not clip
        # float -> integer PCM by default, so clip to keep out-of-range samples
        # from being distorted or wrapped in the PCM_16 file.
        data = np.clip(data, -1.0, 1.0)
        base, _ext = os.path.splitext(input_path)
        output_path = f"{base}_processed_for_prediction.wav"
        sf.write(output_path, data, target_sr, subtype='PCM_16')
        return output_path
    except Exception as e:
        st.error(f"Error preprocessing audio: {e}")
        return None

def extract_features(file_path):
    """Compute the vocal feature dictionary the model consumes for one audio file.

    ``file_path`` is expected to be the output of ``preprocess_audio``. The VAD
    step reads raw frames with the ``wave`` module and sizes them at 2 bytes per
    sample (see ``frame_generator``), so it assumes a mono 16-bit PCM WAV.

    NOTE(review): these values presumably must match the pipeline used to build
    the training set exactly. That includes the constants, the frame-boundary
    condition and the order of the float accumulation, so do not "tidy" them
    without retraining.

    Returned keys:
        Duration: signal length in seconds, from ``librosa.get_duration``.
        Pitch_Mean, Pitch_Std: mean/std of Praat pitch over frames with a nonzero
            pitch value (0 when no such frame exists).
        Jitter, Shimmer: Praat "local" jitter/shimmer measures.
        Speaking_Rate: number of 30 ms frames webrtcvad marks as speech, per
            second of audio. It counts frames, not separate segments.
        Silence_Ratio: share of the duration not covered by speech frames,
            floored at 0.
        MFCC_1 .. MFCC_13: per-coefficient mean of 13 MFCCs over time.

    Returns:
        The feature dict. Returns ``None`` if any step raises; the error is
        shown via ``st.error``.
    """
    try:
        y, sr = librosa.load(file_path, sr=None)
        duration = librosa.get_duration(y=y, sr=sr)
        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
        # Average each coefficient across frames -> 13 values.
        mfcc_means = np.mean(mfccs, axis=1)

        snd = parselmouth.Sound(file_path)
        pitch = snd.to_pitch()
        pitch_values = pitch.selected_array['frequency']
        # Drop zero entries (frames where no pitch was detected) before taking statistics.
        pitch_values = pitch_values[pitch_values != 0]

        pitch_mean = np.mean(pitch_values) if len(pitch_values) > 0 else 0
        pitch_std = np.std(pitch_values) if len(pitch_values) > 0 else 0

        # Positional arguments follow Praat's command signatures. 75/500 are
        # presumably the pitch floor/ceiling in Hz; the jitter/shimmer numbers
        # are Praat's time-range/period/amplitude-factor parameters. Confirm in
        # the Praat manual before changing them.
        point_process = parselmouth.praat.call(snd, "To PointProcess (periodic, cc)", 75, 500)
        jitter_local = parselmouth.praat.call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
        shimmer_local = parselmouth.praat.call([snd, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)

        def read_wave(path):
            """Return (raw PCM bytes of all frames, sample rate) of a WAV file."""
            with contextlib.closing(wave.open(path, 'rb')) as wf:
                pcm_data, sample_rate = wf.readframes(wf.getnframes()), wf.getframerate()
                return pcm_data, sample_rate

        def frame_generator(frame_duration_ms, audio, sample_rate):
            """Yield consecutive byte chunks of ``frame_duration_ms`` each.

            The ``* 2`` assumes 2 bytes per sample (16-bit mono).
            NOTE(review): the ``<`` bound also drops a final frame that ends
            exactly at the end of the buffer. It is left as-is so the features
            stay consistent with training.
            """
            n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
            offset = 0
            while offset + n < len(audio):
                yield audio[offset:offset + n]
                offset += n

        # webrtcvad aggressiveness mode 1.
        vad = webrtcvad.Vad(1)
        audio, sample_rate = read_wave(file_path)
        frames = list(frame_generator(30, audio, sample_rate))
        voiced_seconds = 0
        # Counts speech-classified frames, not contiguous speech segments.
        num_segments = 0
        if frames:
            for frame in frames:
                if vad.is_speech(frame, sample_rate):
                    voiced_seconds += 0.03 # add one 30 ms frame; repeated float addition kept as in training
                    num_segments +=1

        silence_ratio = max(0, (duration - voiced_seconds) / duration) if duration > 0 else 0
        speaking_rate = num_segments / duration if duration > 0 else 0

        features = {
            'Duration': duration,
            'Pitch_Mean': pitch_mean,
            'Pitch_Std': pitch_std,
            'Jitter': jitter_local,
            'Shimmer': shimmer_local,
            'Speaking_Rate': speaking_rate,
            'Silence_Ratio': silence_ratio,
        }
        # Keys MFCC_1 .. MFCC_13 (1-based).
        for idx, val in enumerate(mfcc_means):
            features[f'MFCC_{idx+1}'] = val

        return features
    except Exception as e:
        st.error(f"Error extracting features: {e}")
        return None


# --- Main Prediction Logic (Refactored for Streamlit) ---
def predict(audio_file_path, model, scaler, feature_names):
    """Run the full pipeline on one audio file and return the model's output.

    The pipeline has four steps:

    1. Convert the file with ``preprocess_audio``.
    2. Extract features from the converted copy, then delete that copy.
    3. Order the features as ``feature_names`` and scale them with ``scaler``.
    4. Run ``model`` on the scaled features.

    Args:
        audio_file_path: Path to the uploaded audio file.
        model: Loaded Keras model. Element ``[0][0]`` of its prediction is read
            as the Parkinson's score; the UI thresholds it at 0.5.
        scaler: Fitted ``StandardScaler`` matching ``feature_names``.
        feature_names: Column order the scaler and model expect.

    Returns:
        ``(prediction_prob, features_dict)``, where ``prediction_prob`` is a
        builtin ``float``. Returns ``(None, None)`` if preprocessing or feature
        extraction failed.

    Raises:
        KeyError: if ``feature_names`` lists a feature that was not extracted.
        Errors from ``scaler.transform`` / ``model.predict`` propagate unchanged.
    """
    processed_path = preprocess_audio(audio_file_path)
    if not processed_path:
        return None, None

    try:
        features_dict = extract_features(processed_path)
    finally:
        # Always delete the intermediate WAV, even if extraction raises.
        os.remove(processed_path)
    if not features_dict:
        return None, None

    # Build a one-row frame and select/order columns as the scaler expects.
    feature_df = pd.DataFrame([features_dict])[feature_names]

    scaled_features = scaler.transform(feature_df)

    # model.predict returns a NumPy array (float32 by default in Keras).
    # numpy.float32 is not a subclass of float, and callers such as st.progress
    # type-check for builtin int/float, so convert explicitly.
    prediction_prob = float(model.predict(scaled_features, verbose=0)[0][0])

    return prediction_prob, features_dict


# --- Streamlit App Interface ---

st.set_page_config(page_title="Parkinson's Voice Detector", page_icon="🩺", layout="centered")

st.title("🩺 Parkinson's Disease Detection from Voice")
st.markdown("""

This app uses a machine learning model to predict the likelihood of Parkinson's disease based on vocal features. 

Upload a short voice recording (e.g., of someone saying "ahhh" for a few seconds) to get a prediction.



**Disclaimer:** This is a demonstration tool and not a substitute for professional medical advice.

""")

# Load the cached model artifacts; nothing below can run without them.
try:
    model, scaler, feature_names = load_models_and_scaler()
    st.sidebar.success("Model and components loaded successfully!")
except Exception as e:
    st.error(f"Error loading model components: {e}")
    st.stop()  # End this script run; the rest of the page depends on the model.

# File uploader
uploaded_file = st.file_uploader(
    "Choose a voice recording...",
    type=["wav", "mp3", "ogg", "flac"]
)

if uploaded_file is not None:
    # Play back using the upload's own MIME type; mp3/ogg/flac are accepted too.
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')

    # When the user clicks the button, start prediction.
    if st.button("Analyze Audio", type="primary"):
        # Persist the upload so librosa/parselmouth can read it from disk. Keep
        # the original extension in case a decoder relies on it.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_file_path = tmp_file.name

        with st.spinner('Analyzing audio... This may take a moment.'):
            try:
                prediction_prob, features = predict(tmp_file_path, model, scaler, feature_names)

                if prediction_prob is not None:
                    # st.progress only accepts builtin int/float; a numpy scalar
                    # (e.g. float32 from Keras) would raise mid-render.
                    prediction_prob = float(prediction_prob)

                    # Display results
                    st.subheader("Analysis Result")
                    is_parkinsons = prediction_prob > 0.5

                    if is_parkinsons:
                        st.warning(f"**Parkinson's Detected** (Confidence: {prediction_prob:.2%})")
                    else:
                        st.success(f"**Healthy** (Confidence of being healthy: {(1-prediction_prob):.2%})")

                    # Display confidence as a progress bar
                    st.progress(prediction_prob)
                    st.markdown(f"The model's confidence score for the presence of Parkinson's is **{prediction_prob:.2%}**.")

                    # Show extracted features in an expander. Numpy scalars are not
                    # JSON-native, so convert every value to a builtin float first.
                    with st.expander("View Extracted Vocal Features"):
                        st.json({name: float(value) for name, value in features.items()})
                else:
                    st.error("Could not process the audio file. Please try a different file.")

            except Exception as e:
                st.error(f"An unexpected error occurred during analysis: {e}")
            finally:
                # Clean up the temporary file
                os.remove(tmp_file_path)