Jeremy Mattathias Mboe
Fix: Correct architecture - input_size=120 concat mode, bidirectional=False
5adf7a9
"""
LSTM Model Architecture for Respiratory Disease Classification
"""
import torch
import torch.nn as nn
from pathlib import Path
# Dummy Config class to handle legacy model loading
# This is needed because the saved model references a Config class
class Config:
    """Placeholder configuration kept for backward compatibility.

    Previously saved model files reference a class named ``Config``, so the
    name has to be resolvable in this module when such a file is loaded.
    Every instance simply carries the default values listed below.
    """

    def __init__(self):
        # Insertion order of this table is the order in which the attributes
        # are created on the instance.
        defaults = {
            # Filesystem locations
            'BASE_PATH': Path('.'),
            'COUGH_PATH': Path('.'),
            'VOWEL_PATH': Path('.'),
            # Audio / feature-extraction settings
            'SAMPLE_RATE': 16000,
            'DURATION': 2.0,
            'N_FFT': 2048,
            'HOP_LENGTH': 512,
            'WIN_LENGTH': 2048,
            'N_MFCC': 20,
            'N_MELS': 64,
            # Model hyperparameters
            'INPUT_SIZE': 60,
            'HIDDEN_SIZE': 128,
            'NUM_LAYERS': 2,
            'OUTPUT_SIZE': 3,
            'DROPOUT': 0.4,
            'BIDIRECTIONAL': True,
            # Training settings
            'BATCH_SIZE': 16,
            'LEARNING_RATE': 0.0005,
            'COMBINE_MODE': "concat",
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)
class RespiratoryLSTM(nn.Module):
    """LSTM-based model for respiratory disease classification with attention.

    Data flow: batch normalization of the input features -> multi-layer
    (optionally bidirectional) LSTM -> attention-weighted sum over time ->
    fully connected classification head.

    The submodule attribute names (``batch_norm_input``, ``lstm``,
    ``attention``, ``classifier``) and the layer order inside each
    ``nn.Sequential`` define the ``state_dict`` keys; keep them unchanged so
    previously saved weights continue to load.
    """

    def __init__(self, input_size=120, hidden_size=128, num_layers=2,
                 output_size=3, dropout=0.4, bidirectional=True):
        """
        Initialize the LSTM model

        Args:
            input_size: Input feature dimension (60*2=120 for concat mode)
            hidden_size: Hidden state dimension
            num_layers: Number of LSTM layers
            output_size: Number of output classes (3: Healthy, COPD, Asthma)
            dropout: Dropout rate (used between LSTM layers and in the
                classification head)
            bidirectional: Whether to use bidirectional LSTM
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Input batch normalization (one channel per input feature)
        self.batch_norm_input = nn.BatchNorm1d(input_size)

        # LSTM layers; nn.LSTM only applies dropout between stacked layers,
        # so it is switched off for a single-layer network.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # A bidirectional LSTM concatenates both directions per time step
        lstm_output_size = hidden_size * 2 if bidirectional else hidden_size

        # Attention mechanism: one scalar score per time step
        self.attention = nn.Sequential(
            nn.Linear(lstm_output_size, lstm_output_size // 2),
            nn.Tanh(),
            nn.Linear(lstm_output_size // 2, 1)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, output_size)
        )

    def forward(self, audio_features, lengths=None):
        """
        Forward pass

        Args:
            audio_features: (batch_size, seq_len, feature_dim)
            lengths: (batch_size,) tensor with the actual length of each
                sequence (optional). When given, padded time steps are
                skipped by the LSTM and excluded from the attention softmax.

        Returns:
            logits: (batch_size, output_size)
        """
        # BatchNorm1d normalizes over dim 1, so move the feature axis there
        # for the normalization and back again afterwards.
        normed_features = self.batch_norm_input(
            audio_features.transpose(1, 2)
        ).transpose(1, 2)

        # LSTM forward pass
        if lengths is not None:
            # Pack padded sequences so the LSTM never sees padding frames
            packed_input = nn.utils.rnn.pack_padded_sequence(
                normed_features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            packed_output, _ = self.lstm(packed_input)
            lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        else:
            lstm_output, _ = self.lstm(normed_features)

        # Attention scores per time step: (batch_size, seq_len, 1)
        attention_scores = self.attention(lstm_output)

        if lengths is not None:
            # pad_packed_sequence zero-fills padded steps, but the attention
            # network still produces a finite score for them. Without a mask
            # they would take part in the softmax and shrink the weights of
            # the real frames, making the result depend on the amount of
            # padding in the batch. Scores of -inf give them zero weight.
            time_index = torch.arange(lstm_output.size(1), device=lstm_output.device)
            is_padding = time_index.unsqueeze(0) >= lengths.to(lstm_output.device).unsqueeze(1)
            attention_scores = attention_scores.masked_fill(
                is_padding.unsqueeze(-1), float('-inf')
            )

        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum using attention: (batch_size, lstm_output_size)
        attended_output = torch.sum(lstm_output * attention_weights, dim=1)

        # Classification
        return self.classifier(attended_output)
def load_lstm_model(model_path, device='cpu'):
    """
    Build the classifier and restore its pre-trained weights.

    Args:
        model_path: Path to the .pth model file
        device: Device to load model on ('cpu' or 'cuda')

    Returns:
        The model, moved to ``device`` and switched to evaluation mode
    """
    # Architecture actually used during training, as reflected by the saved
    # weights: 120 input features (60*2, cough and vowel features
    # concatenated), hidden size 128, and a unidirectional LSTM.
    architecture = dict(
        input_size=120,
        hidden_size=128,
        num_layers=2,
        output_size=3,
        dropout=0.4,
        bidirectional=False,
    )
    net = RespiratoryLSTM(**architecture)

    # weights_only=False is required for checkpoints that carry metadata
    # objects in addition to tensors.
    # NOTE(security): this performs full unpickling, which can execute
    # arbitrary code - only load checkpoint files from a trusted source.
    payload = torch.load(model_path, map_location=device, weights_only=False)

    # The file may hold a bare state dict, or a dict wrapping it under one
    # of the keys below; the first key found wins.
    weights = payload
    if isinstance(payload, dict):
        for wrapper_key in ('model_state_dict', 'state_dict'):
            if wrapper_key in payload:
                weights = payload[wrapper_key]
                break
    net.load_state_dict(weights)

    net.to(device)
    net.eval()
    return net
def predict_audio(model, audio_features, device='cpu'):
    """
    Classify a single recording with the LSTM model.

    Args:
        model: Loaded LSTM model
        audio_features: Extracted audio features (n_frames, n_features)
        device: Device to run inference on

    Returns:
        prediction: Class prediction (0, 1, or 2)
        probabilities: Probability distribution over classes
    """
    # Inference only: the model is put into evaluation mode
    model.eval()

    # Float tensor with a leading batch axis of size 1, placed on the device
    batch = torch.FloatTensor(audio_features).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch)
        class_probs = torch.softmax(scores, dim=1)

    # Index of the highest-scoring class as a plain Python int
    predicted_class = scores.argmax(dim=1).item()

    # Drop the batch axis and hand the probabilities back as a numpy array
    return predicted_class, class_probs.cpu().numpy()[0]