|
|
import os
|
|
|
import numpy as np
|
|
|
import librosa
|
|
|
import librosa.display
|
|
|
import matplotlib.pyplot as plt
|
|
|
import noisereduce as nr
|
|
|
import random
|
|
|
import shutil
|
|
|
|
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
|
from tensorflow.keras.models import Sequential
|
|
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
|
|
|
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
|
|
|
from tensorflow.keras.optimizers import Adam
|
|
|
|
|
|
|
|
|
# --- Configuration ---------------------------------------------------------
# Root folder of the raw audio, organised as data/<category>/<file>.wav|.mp3
DATA_SOURCE_PATH = 'data'

# Output root for the generated train/validation spectrogram images.
SPECTROGRAM_PATH = 'spectrograms_stft_5s_grayscale'

# Where the best Keras model checkpoint is written (HDF5 format).
MODEL_SAVE_PATH = 'parkinson_cnn_model_stft_grayscale.h5'

# Every spectrogram image is resized to this resolution before training.
IMG_HEIGHT, IMG_WIDTH = 224, 224

# Mini-batch size for both the train and validation generators.
BATCH_SIZE = 32

# Each audio clip is trimmed/zero-padded to this many seconds.
TARGET_DURATION_S = 5
|
|
|
|
|
|
|
|
|
def augment_audio(y, sr):
    """Return a randomly augmented copy of an audio signal.

    Applies, in order: a random pitch shift (+/- 2 semitones), a random
    time stretch (0.9x - 1.1x), and low-level additive Gaussian noise.

    Parameters
    ----------
    y : np.ndarray
        Mono audio samples.
    sr : int
        Sample rate of ``y``.

    Returns
    -------
    np.ndarray
        Augmented signal restored to ``len(y)`` samples, so downstream
        spectrograms keep a consistent time extent.
    """
    original_length = len(y)
    y_aug = y.copy()

    # Random pitch shift of up to +/- 2 semitones.
    pitch_steps = random.uniform(-2, 2)
    y_aug = librosa.effects.pitch_shift(y_aug, sr=sr, n_steps=pitch_steps)

    # Random tempo change (0.9x slower to 1.1x faster).
    stretch_rate = random.uniform(0.9, 1.1)
    y_aug = librosa.effects.time_stretch(y_aug, rate=stretch_rate)

    # BUG FIX: time_stretch changes the number of samples, so the
    # augmented clip no longer matched TARGET_DURATION_S. Restore the
    # original duration: centre-trim if longer, zero-pad if shorter.
    if len(y_aug) > original_length:
        start = (len(y_aug) - original_length) // 2
        y_aug = y_aug[start:start + original_length]
    else:
        y_aug = librosa.util.pad_center(y_aug, size=original_length)

    # Additive Gaussian noise scaled to the signal's peak amplitude.
    # BUG FIX: use the max *absolute* amplitude; np.amax(y) alone could
    # be near zero for a predominantly negative waveform.
    noise_amp = 0.005 * np.random.uniform() * np.amax(np.abs(y))
    y_aug = y_aug + noise_amp * np.random.normal(size=len(y_aug))

    return y_aug
|
|
|
|
|
|
|
|
|
def create_stft_spectrogram(audio_file, save_path, augment=False):
    """
    Creates a high-quality GRAYSCALE spectrogram from a standardized 5s audio segment.

    Loads the audio at its native sample rate, extracts the centre
    TARGET_DURATION_S seconds (zero-padding shorter clips), optionally
    augments it, denoises it, and saves a log-frequency STFT spectrogram
    as an axis-free grayscale PNG at ``save_path``.

    Parameters
    ----------
    audio_file : str
        Path to the source .wav/.mp3 file.
    save_path : str
        Destination path for the PNG image.
    augment : bool
        When True, apply ``augment_audio`` to the segment first.

    Returns
    -------
    bool
        True on success; False on any failure (the error is printed so
        batch processing can continue).
    """
    fig_created = False
    try:
        # sr=None preserves the file's native sample rate.
        y, sr = librosa.load(audio_file, sr=None)

        # int() guards against a non-integer sample rate producing a
        # float sample count (slicing with a float would raise).
        target_samples = int(TARGET_DURATION_S * sr)

        if len(y) > target_samples:
            # Use the centre window of longer recordings.
            start_index = (len(y) - target_samples) // 2
            y_segment = y[start_index : start_index + target_samples]
        else:
            # Zero-pad shorter recordings symmetrically.
            y_segment = librosa.util.pad_center(y, size=target_samples)

        if augment:
            y_segment = augment_audio(y_segment, sr)

        # Spectral-gating noise reduction before the STFT.
        y_reduced = nr.reduce_noise(y=y_segment, sr=sr)

        N_FFT = 1024
        HOP_LENGTH = 256
        S_audio = librosa.stft(y_reduced, n_fft=N_FFT, hop_length=HOP_LENGTH)
        # Convert magnitude to dB, referenced to the segment's peak.
        Y_db = librosa.amplitude_to_db(np.abs(S_audio), ref=np.max)

        plt.figure(figsize=(12, 4))
        fig_created = True
        librosa.display.specshow(Y_db, sr=sr, hop_length=HOP_LENGTH,
                                 x_axis='time', y_axis='log', cmap='gray_r')
        plt.axis('off')  # the CNN needs pixels only — no axes or labels
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        return True

    except Exception as e:
        # Best-effort batch processing: log the bad file and move on.
        print(f" - Error processing {audio_file}: {e}")
        return False
    finally:
        # BUG FIX: always close the figure, even when specshow/savefig
        # raises; previously a failure after plt.figure() leaked one
        # matplotlib figure per bad file.
        if fig_created:
            plt.close()
|
|
|
|
|
|
|
|
|
def process_all_audio_files():
    """Convert every source audio file into train/validation spectrograms.

    Rebuilds ``SPECTROGRAM_PATH`` from scratch, splits each category's
    files 80/20 into train/validation, then writes one clean spectrogram
    per validation file and three per training file (one clean plus two
    augmented variants).
    """
    # Start from a clean slate so stale images never pollute a new run.
    if os.path.exists(SPECTROGRAM_PATH):
        shutil.rmtree(SPECTROGRAM_PATH)

    print(f"Starting audio to Grayscale STFT Spectrogram conversion ({TARGET_DURATION_S}s)...")
    for split in ['train', 'validation']:
        for category in ['parkinson', 'healthy']:
            os.makedirs(os.path.join(SPECTROGRAM_PATH, split, category), exist_ok=True)

    for category in ['parkinson', 'healthy']:
        source_dir = os.path.join(DATA_SOURCE_PATH, category)

        # ROBUSTNESS FIX: a missing category folder previously crashed
        # the whole run with FileNotFoundError; warn and continue.
        if not os.path.isdir(source_dir):
            print(f" - Warning: source directory not found, skipping: {source_dir}")
            continue

        all_files = [f for f in os.listdir(source_dir) if f.lower().endswith(('.wav', '.mp3'))]
        if not all_files:
            continue

        # Shuffle before splitting so the 80/20 split is random.
        random.shuffle(all_files)
        split_index = int(len(all_files) * 0.8)
        train_files, validation_files = all_files[:split_index], all_files[split_index:]

        print(f"--- Processing Category: {category} ---")
        for filename in train_files:
            file_path = os.path.join(source_dir, filename)
            base_name = os.path.splitext(filename)[0]
            # i == 0 is the unaugmented copy; i == 1, 2 are augmented.
            for i in range(3):
                save_path = os.path.join(SPECTROGRAM_PATH, 'train', category,
                                         f"{base_name}_aug_{i}.png")
                create_stft_spectrogram(file_path, save_path, augment=(i > 0))

        # Validation images are never augmented.
        for filename in validation_files:
            file_path = os.path.join(source_dir, filename)
            base_name = os.path.splitext(filename)[0]
            save_path = os.path.join(SPECTROGRAM_PATH, 'validation', category,
                                     f"{base_name}.png")
            create_stft_spectrogram(file_path, save_path, augment=False)

    print("Spectrogram generation complete.")
|
|
|
|
|
|
|
|
|
def train_cnn_model():
|
|
|
"""
|
|
|
Trains a CNN model optimized for grayscale spectrograms.
|
|
|
"""
|
|
|
if not os.path.exists(SPECTROGRAM_PATH):
|
|
|
print("Spectrograms not found.")
|
|
|
return
|
|
|
|
|
|
train_datagen = ImageDataGenerator(rescale=1./255)
|
|
|
validation_datagen = ImageDataGenerator(rescale=1./255)
|
|
|
|
|
|
|
|
|
train_generator = train_datagen.flow_from_directory(
|
|
|
os.path.join(SPECTROGRAM_PATH, 'train'),
|
|
|
target_size=(IMG_HEIGHT, IMG_WIDTH),
|
|
|
batch_size=BATCH_SIZE,
|
|
|
class_mode='binary',
|
|
|
color_mode='grayscale'
|
|
|
)
|
|
|
validation_generator = validation_datagen.flow_from_directory(
|
|
|
os.path.join(SPECTROGRAM_PATH, 'validation'),
|
|
|
target_size=(IMG_HEIGHT, IMG_WIDTH),
|
|
|
batch_size=BATCH_SIZE,
|
|
|
class_mode='binary',
|
|
|
color_mode='grayscale'
|
|
|
)
|
|
|
|
|
|
if not train_generator.samples > 0:
|
|
|
print("Error: No training images were generated.")
|
|
|
return
|
|
|
|
|
|
|
|
|
model = Sequential([
|
|
|
Conv2D(32, (3, 3), activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 1), padding='same'),
|
|
|
BatchNormalization(),
|
|
|
MaxPooling2D((2, 2)),
|
|
|
Conv2D(64, (3, 3), activation='relu', padding='same'),
|
|
|
BatchNormalization(),
|
|
|
MaxPooling2D((2, 2)),
|
|
|
Conv2D(128, (3, 3), activation='relu', padding='same'),
|
|
|
BatchNormalization(),
|
|
|
MaxPooling2D((2, 2)),
|
|
|
Flatten(),
|
|
|
Dense(256, activation='relu'),
|
|
|
BatchNormalization(),
|
|
|
Dropout(0.5),
|
|
|
Dense(128, activation='relu'),
|
|
|
Dropout(0.4),
|
|
|
Dense(1, activation='sigmoid')
|
|
|
])
|
|
|
|
|
|
model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
|
|
|
model.summary()
|
|
|
|
|
|
callbacks_list = [
|
|
|
EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
|
|
|
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=7),
|
|
|
ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_accuracy', save_best_only=True, mode='max')
|
|
|
]
|
|
|
|
|
|
model.fit(
|
|
|
train_generator,
|
|
|
epochs=100,
|
|
|
validation_data=validation_generator,
|
|
|
callbacks=callbacks_list
|
|
|
)
|
|
|
print(f"Grayscale model training complete. Best model saved to {MODEL_SAVE_PATH}")
|
|
|
|
|
|
|
|
|
def _main():
    """Run the full pipeline: generate spectrograms, then train the CNN."""
    process_all_audio_files()
    train_cnn_model()


if __name__ == '__main__':
    _main()