StrokeMitra-API / src /models /ensemble.py
DhruvB1906's picture
Upload folder using huggingface_hub
4e9a3bc verified
"""Ensemble model combining HuBERT-SALR and CNN-BiLSTM branches."""
import logging
import numpy as np
logger = logging.getLogger(__name__)
class EnsembleModel:
    """
    Ensemble model for dysarthria detection.

    Blends the two-class logits of a HuBERT-SALR branch and a CNN-BiLSTM
    branch as a weighted average, then applies a softmax.

    For MVP: Uses placeholder branch models that draw random predictions and
    do not look at their inputs.
    In production: Uses trained HuBERT-SALR and CNN-BiLSTM models.
    """

    # Keeps the denominator of the odds ratio away from zero in the
    # placeholder branches.
    _EPS = 1e-8

    def __init__(self, alpha: float = 0.6, rng: "np.random.Generator | None" = None):
        """
        Initialize ensemble.

        Args:
            alpha: Weight for HuBERT branch (1-alpha for CNN branch).
                Must lie in [0, 1].
            rng: Optional random source exposing ``beta(a, b)`` (for example
                ``np.random.default_rng(seed)``) used by the placeholder
                branches. When omitted, the global ``np.random`` state is
                used.

        Raises:
            ValueError: If ``alpha`` is outside [0, 1] (this includes NaN).
        """
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        self.alpha = alpha
        self.version = "ensemble-v1.0-placeholder"
        self._rng = np.random if rng is None else rng
        logger.info("EnsembleModel initialized (alpha=%s)", alpha)

    def predict(
        self,
        waveform: np.ndarray,
        spectrogram: np.ndarray,
        acoustic_features: np.ndarray,
    ) -> dict:
        """
        Run ensemble prediction.

        Args:
            waveform: Audio waveform for HuBERT branch
            spectrogram: Spectrogram for CNN branch
            acoustic_features: Acoustic features for fusion

        Returns:
            Dictionary with keys:
                "logits": weighted average of the two branch logits
                "probabilities": softmax of "logits"
                "raw_probability": probability at index 1 (dysarthric class)
                "hubert_logits": logits from the HuBERT branch
                "cnn_logits": logits from the CNN branch
                "alpha": the blend weight that was applied
        """
        logger.debug("Running ensemble prediction (placeholder)")

        # Placeholder: Generate mock predictions from both branches
        hubert_logits = self._mock_hubert_branch(waveform)
        cnn_logits = self._mock_cnn_branch(spectrogram, acoustic_features)

        # Ensemble: weighted average of logits
        ensemble_logits = self.alpha * hubert_logits + (1 - self.alpha) * cnn_logits

        # Softmax; subtracting the max first avoids overflow in exp().
        exp_logits = np.exp(ensemble_logits - np.max(ensemble_logits))
        probs = exp_logits / np.sum(exp_logits)
        raw_probability = float(probs[1])  # Probability of dysarthric class

        logger.info("Ensemble prediction: prob_dysarthric=%.3f", raw_probability)
        return {
            "logits": ensemble_logits,
            "probabilities": probs,
            "raw_probability": raw_probability,
            "hubert_logits": hubert_logits,
            "cnn_logits": cnn_logits,
            "alpha": self.alpha,
        }

    def _mock_logits(self, a: float, b: float) -> np.ndarray:
        """Draw prob ~ Beta(a, b) and return [healthy, dysarthric] log-odds."""
        # Fall back to the global state if __init__ was bypassed.
        rng = getattr(self, "_rng", None)
        if rng is None:
            rng = np.random
        prob = rng.beta(a, b)
        logit_healthy = np.log((1 - prob) / (prob + self._EPS))
        logit_dysarthric = np.log(prob / (1 - prob + self._EPS))
        return np.array([logit_healthy, logit_dysarthric])

    def _mock_hubert_branch(self, waveform: np.ndarray) -> np.ndarray:
        """Mock HuBERT-SALR branch prediction (the input is not used)."""
        # Beta(2, 5): bias towards healthy
        return self._mock_logits(2, 5)

    def _mock_cnn_branch(self, spectrogram: np.ndarray, acoustic_features: np.ndarray) -> np.ndarray:
        """Mock CNN-BiLSTM-Transformer branch prediction (inputs are not used)."""
        # Beta(2.5, 5.5): slightly different distribution from the HuBERT mock
        return self._mock_logits(2.5, 5.5)