Spaces:
Build error
Build error
File size: 7,067 Bytes
cab8122 7ebe942 cab8122 615e67b 5adf7a9 cab8122 5adf7a9 cab8122 615e67b cab8122 7ebe942 cab8122 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | """
LSTM Model Architecture for Respiratory Disease Classification
"""
import torch
import torch.nn as nn
from pathlib import Path
# Backward-compatibility shim: the saved model references a class named
# ``Config``, so a class with that name has to be resolvable when the
# checkpoint is unpickled.
class Config:
    """Placeholder for the training-time configuration object.

    When an instance is unpickled, pickle restores the saved attributes
    directly and does not call ``__init__``. The values assigned below are
    only the defaults of a freshly constructed ``Config()``.

    NOTE(review): several defaults here (INPUT_SIZE=60, BIDIRECTIONAL=True)
    differ from the architecture ``load_lstm_model`` actually builds
    (input_size=120, bidirectional=False). Nothing in this module reads them.
    """

    def __init__(self):
        # Dataset locations
        self.BASE_PATH = Path('.')
        self.COUGH_PATH = Path('.')
        self.VOWEL_PATH = Path('.')

        # Audio / feature-extraction parameters
        self.SAMPLE_RATE = 16000
        self.DURATION = 2.0
        self.N_FFT = 2048
        self.HOP_LENGTH = 512
        self.WIN_LENGTH = 2048
        self.N_MFCC = 20
        self.N_MELS = 64

        # Model hyper-parameters
        self.INPUT_SIZE = 60
        self.HIDDEN_SIZE = 128
        self.NUM_LAYERS = 2
        self.OUTPUT_SIZE = 3
        self.DROPOUT = 0.4
        self.BIDIRECTIONAL = True

        # Training parameters
        self.BATCH_SIZE = 16
        self.LEARNING_RATE = 0.0005
        self.COMBINE_MODE = "concat"
class RespiratoryLSTM(nn.Module):
    """LSTM-based model for respiratory disease classification with attention pooling.

    Pipeline: BatchNorm over the feature axis -> multi-layer (optionally
    bidirectional) LSTM -> attention that pools the time axis into a single
    vector -> MLP head producing ``output_size`` logits.

    Submodule names and their order define the checkpoint's ``state_dict``
    keys (e.g. ``attention.0.weight``, ``classifier.6.bias``). Renaming or
    reordering them breaks loading of previously saved weights.
    """

    def __init__(self, input_size=120, hidden_size=128, num_layers=2,
                 output_size=3, dropout=0.4, bidirectional=True):
        """
        Initialize the LSTM model.

        Args:
            input_size: Input feature dimension (60*2=120 for concat mode).
            hidden_size: Hidden state dimension of each LSTM direction; also
                the width of the first classifier layer.
            num_layers: Number of stacked LSTM layers.
            output_size: Number of output classes (3: Healthy, COPD, Asthma).
            dropout: Dropout rate. Used between LSTM layers (only when
                num_layers > 1) and inside the classifier head.
            bidirectional: Whether to use a bidirectional LSTM.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Per-feature normalization; forward() feeds it (batch, feature, time).
        self.batch_norm_input = nn.BatchNorm1d(input_size)

        # nn.LSTM only applies dropout *between* layers (and warns if it is set
        # with a single layer), so pass it only when there is more than one.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates both directions' hidden states.
        lstm_output_size = hidden_size * 2 if bidirectional else hidden_size

        # Produces one unnormalized score per timestep; forward() softmaxes
        # the scores over time.
        self.attention = nn.Sequential(
            nn.Linear(lstm_output_size, lstm_output_size // 2),
            nn.Tanh(),
            nn.Linear(lstm_output_size // 2, 1),
        )

        # Classification head. BatchNorm1d here needs batch size > 1 while in
        # training mode; in eval mode it uses running statistics.
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, output_size),
        )

    def forward(self, audio_features, lengths=None):
        """
        Forward pass.

        Args:
            audio_features: (batch_size, seq_len, feature_dim) tensor.
            lengths: Optional (batch_size,) valid length of each sequence
                (tensor or list of ints, each >= 1). Timesteps past a
                sequence's length are skipped by the LSTM and receive zero
                attention weight, so a padded sequence scores the same as the
                unpadded one. When omitted, every timestep is treated as valid.

        Returns:
            logits: (batch_size, output_size)
        """
        # BatchNorm1d normalizes dim 1, so move features there and back.
        # NOTE(review): padded frames still contribute to the batch statistics
        # of this layer in training mode.
        normed = self.batch_norm_input(audio_features.transpose(1, 2)).transpose(1, 2)

        if lengths is None:
            lstm_output, _ = self.lstm(normed)
            valid_mask = None
        else:
            # pack_padded_sequence requires lengths on the CPU.
            lengths_cpu = torch.as_tensor(lengths).cpu()
            packed_input = nn.utils.rnn.pack_padded_sequence(
                normed, lengths_cpu, batch_first=True, enforce_sorted=False
            )
            packed_output, _ = self.lstm(packed_input)
            # The output is padded with zeros up to max(lengths).
            lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
            positions = torch.arange(lstm_output.size(1), device=lstm_output.device)
            # (batch, time): True where the timestep is real data.
            valid_mask = positions.unsqueeze(0) < lengths_cpu.to(lstm_output.device).unsqueeze(1)

        scores = self.attention(lstm_output)  # (batch_size, seq_len, 1)
        if valid_mask is not None:
            # Without this, zero-padded steps would still get softmax weight
            # (their score comes from the layer biases) and dilute the pooled
            # vector.
            scores = scores.masked_fill(~valid_mask.unsqueeze(-1), float('-inf'))
        attention_weights = torch.softmax(scores, dim=1)

        # Attention-weighted sum over time -> (batch_size, lstm_output_size)
        attended_output = torch.sum(lstm_output * attention_weights, dim=1)
        return self.classifier(attended_output)
def load_lstm_model(model_path, device='cpu'):
    """
    Load the pre-trained LSTM model and put it in evaluation mode.

    Args:
        model_path: Path to the .pth model file.
        device: Device to load the model on ('cpu' or 'cuda').

    Returns:
        RespiratoryLSTM moved to ``device``, in eval mode.

    Accepted checkpoint layouts:
        * dict containing ``'model_state_dict'``
        * dict containing ``'state_dict'``
        * a bare state dict
        * a whole pickled ``nn.Module`` (its ``state_dict()`` is used)
    """
    # Architecture must match the one the weights were trained with:
    # input_size=120 (60*2, concatenated cough and vowel features),
    # hidden=128, 2 layers, and a UNIDIRECTIONAL LSTM.
    model = RespiratoryLSTM(
        input_size=120,
        hidden_size=128,
        num_layers=2,
        output_size=3,
        dropout=0.4,
        bidirectional=False,
    )

    # SECURITY: weights_only=False fully unpickles the file, which can run
    # arbitrary code embedded in it. It is kept because the checkpoint carries
    # pickled metadata (see the Config shim). Only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Work out which object holds the weights; fall back to the checkpoint
    # itself (bare state dict).
    state_dict = checkpoint
    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        for key in ('model_state_dict', 'state_dict'):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
def predict_audio(model, audio_features, device='cpu'):
    """
    Make a prediction for a single recording using the LSTM model.

    Args:
        model: Loaded LSTM model. It is switched to eval mode as a side effect.
        audio_features: Extracted audio features shaped (n_frames, n_features).
            Any array-like (list, numpy array, or torch tensor) is accepted.
        device: Device to run inference on.

    Returns:
        prediction: Index of the highest-scoring class (int, e.g. 0, 1 or 2).
        probabilities: numpy array with the softmax distribution over classes.

    Raises:
        ValueError: If audio_features is not 2-D.
    """
    model.eval()
    # as_tensor builds the float32 tensor directly on the target device and
    # also accepts existing tensors, unlike the legacy torch.FloatTensor(...).
    features = torch.as_tensor(audio_features, dtype=torch.float32, device=device)
    if features.dim() != 2:
        raise ValueError(
            f"audio_features must be 2-D (n_frames, n_features), got shape {tuple(features.shape)}"
        )

    with torch.no_grad():
        logits = model(features.unsqueeze(0))  # add batch dimension -> (1, n_classes)
        probabilities = torch.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1).item()

    return prediction, probabilities.cpu().numpy()[0]
|