Spaces:
Build error
Build error
Jeremy Mattathias Mboe
Fix: Correct architecture - input_size=120 concat mode, bidirectional=False
5adf7a9 | """ | |
| LSTM Model Architecture for Respiratory Disease Classification | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from pathlib import Path | |
# Stand-in for the training-time Config class. The comment history of this
# module says the saved checkpoint references a class called Config, so a
# class of that name has to be resolvable when the checkpoint is unpickled.
# NOTE(review): pickle resolves classes by the module path recorded at save
# time; confirm this definition is actually the one found during torch.load.
class Config:
    """Backward-compatibility placeholder for the Config object stored with the model.

    Unpickling an instance does not call ``__init__``; the restored object
    carries whatever attributes were saved. The defaults below only apply
    when ``Config()`` is constructed directly.
    """

    def __init__(self):
        # (attribute name, default value) pairs, applied in this order.
        # NOTE(review): BIDIRECTIONAL=True disagrees with load_lstm_model,
        # which builds the network with bidirectional=False.
        defaults = (
            # Dataset locations (placeholders).
            ("BASE_PATH", Path('.')),
            ("COUGH_PATH", Path('.')),
            ("VOWEL_PATH", Path('.')),
            # Audio / feature-extraction settings.
            ("SAMPLE_RATE", 16000),
            ("DURATION", 2.0),
            ("N_FFT", 2048),
            ("HOP_LENGTH", 512),
            ("WIN_LENGTH", 2048),
            ("N_MFCC", 20),
            ("N_MELS", 64),
            # Model / training hyper-parameters.
            ("INPUT_SIZE", 60),
            ("HIDDEN_SIZE", 128),
            ("NUM_LAYERS", 2),
            ("OUTPUT_SIZE", 3),
            ("DROPOUT", 0.4),
            ("BIDIRECTIONAL", True),
            ("BATCH_SIZE", 16),
            ("LEARNING_RATE", 0.0005),
            ("COMBINE_MODE", "concat"),
        )
        for attr_name, attr_value in defaults:
            setattr(self, attr_name, attr_value)
class RespiratoryLSTM(nn.Module):
    """LSTM-based model for respiratory disease classification with attention mechanism.

    Pipeline: BatchNorm over input features -> (optionally bidirectional)
    LSTM -> attention pooling over time steps -> MLP classifier.
    """

    def __init__(self, input_size=120, hidden_size=128, num_layers=2,
                 output_size=3, dropout=0.4, bidirectional=True):
        """
        Initialize the LSTM model

        Args:
            input_size: Input feature dimension (60*2=120 for concat mode)
            hidden_size: Hidden state dimension
            num_layers: Number of LSTM layers
            output_size: Number of output classes (3: Healthy, COPD, Asthma)
            dropout: Dropout rate, used between LSTM layers (only when
                num_layers > 1) and inside the classifier head
            bidirectional: Whether to use bidirectional LSTM
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # NOTE: the attribute names and the nn.Sequential orderings below
        # determine the state_dict keys; keep them unchanged so existing
        # checkpoints still load.

        # Input batch normalization (over the feature dimension)
        self.batch_norm_input = nn.BatchNorm1d(input_size)

        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
        )

        # Per-step LSTM output width: both directions are concatenated.
        lstm_output_size = hidden_size * 2 if bidirectional else hidden_size

        # Attention mechanism: one scalar score per time step
        self.attention = nn.Sequential(
            nn.Linear(lstm_output_size, lstm_output_size // 2),
            nn.Tanh(),
            nn.Linear(lstm_output_size // 2, 1),
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, output_size),
        )

    def forward(self, audio_features, lengths=None):
        """
        Forward pass

        Args:
            audio_features: (batch_size, seq_len, feature_dim)
            lengths: optional (batch_size,) tensor with the number of valid
                time steps per sequence. Steps beyond each length are treated
                as padding and receive zero attention weight.

        Returns:
            logits: (batch_size, output_size)
        """
        # BatchNorm1d normalizes dim 1, so move features there and back:
        # (batch, seq, feat) -> (batch, feat, seq) -> (batch, seq, feat)
        audio_features = self.batch_norm_input(
            audio_features.transpose(1, 2)
        ).transpose(1, 2)

        if lengths is not None:
            # Pack padded sequences so the LSTM ignores padding;
            # pack_padded_sequence requires lengths on the CPU.
            packed_input = nn.utils.rnn.pack_padded_sequence(
                audio_features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            packed_output, _ = self.lstm(packed_input)
            # Re-pads with zeros, up to max(lengths) time steps.
            lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        else:
            lstm_output, _ = self.lstm(audio_features)

        # Attention scores: (batch_size, seq_len, 1)
        attention_scores = self.attention(lstm_output)
        if lengths is not None:
            # Padded steps would otherwise get a non-zero weight (their score
            # comes from the attention biases) and dilute the pooled vector.
            steps = torch.arange(lstm_output.size(1), device=lstm_output.device)
            valid = steps.unsqueeze(0) < lengths.to(lstm_output.device).unsqueeze(1)
            attention_scores = attention_scores.masked_fill(
                ~valid.unsqueeze(-1), float('-inf')
            )
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum over time: (batch_size, lstm_output_size)
        attended_output = torch.sum(lstm_output * attention_weights, dim=1)

        return self.classifier(attended_output)
def _extract_state_dict(checkpoint):
    """Return the state dict held by any supported checkpoint format."""
    if isinstance(checkpoint, nn.Module):
        # A whole model object was saved (torch.save(model, path)).
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        # Training-script style checkpoints wrap the weights under a key.
        for key in ('model_state_dict', 'state_dict'):
            if key in checkpoint:
                return checkpoint[key]
    # Otherwise the object itself is assumed to be a bare state dict.
    return checkpoint


def load_lstm_model(model_path, device='cpu'):
    """
    Load pre-trained LSTM model

    Builds a RespiratoryLSTM with the architecture used at training time,
    loads the checkpoint weights, moves the model to ``device`` and switches
    it to evaluation mode.

    Args:
        model_path: Path to the .pth model file
        device: Device to load model on ('cpu' or 'cuda')

    Returns:
        Loaded model in evaluation mode

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint's weights do
            not match this architecture.
    """
    # Initialize model with the ACTUAL architecture from training:
    # input_size=120 (concatenated), hidden=128, bidirectional=FALSE
    model = RespiratoryLSTM(
        input_size=120,  # 60*2 for concatenated cough and vowel features
        hidden_size=128,
        num_layers=2,
        output_size=3,
        dropout=0.4,
        bidirectional=False,  # Model was trained WITHOUT bidirectional
    )

    # SECURITY: weights_only=False runs the full pickle loader, which can
    # execute arbitrary code embedded in the file. It is used because the
    # checkpoint may contain non-tensor objects (e.g. the Config class this
    # module stubs out). Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(_extract_state_dict(checkpoint))

    model.to(device)
    model.eval()
    return model
def predict_audio(model, audio_features, device='cpu'):
    """
    Make prediction using the LSTM model

    Args:
        model: Loaded LSTM model (switched to eval mode by this call)
        audio_features: Extracted audio features (n_frames, n_features), as a
            nested list, NumPy array or tensor of any numeric dtype
        device: Device to run inference on

    Returns:
        prediction: Class prediction (0, 1, or 2)
        probabilities: Probability distribution over classes (float32 NumPy array)
    """
    model.eval()
    with torch.no_grad():
        # torch.as_tensor converts lists, arrays and tensors of any dtype
        # uniformly; the legacy torch.FloatTensor(...) constructor raises on
        # e.g. float64 tensors.
        features = torch.as_tensor(audio_features, dtype=torch.float32)
        # Add the batch dimension: (1, n_frames, n_features)
        audio_tensor = features.unsqueeze(0).to(device)

        logits = model(audio_tensor)
        probabilities = torch.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1).item()

        # Drop the batch dimension for the caller.
        probs_numpy = probabilities.cpu().numpy()[0]
    return prediction, probs_numpy