Spaces:
Sleeping
Sleeping
| """Model registry for loading models from MLflow or placeholders.""" | |
| import logging | |
| from typing import Optional | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| logger = logging.getLogger(__name__) | |
class ModelRegistry:
    """
    Model registry for loading trained models or placeholders.

    With ``use_placeholder=True`` (the default) the registry hands out
    placeholder objects that generate mock predictions. With
    ``use_placeholder=False`` the ensemble is the trained HuBERT model loaded
    from its default checkpoint; calibration still falls back to the
    placeholder because MLflow calibration loading is not implemented.
    """

    def __init__(self, use_placeholder: bool = True):
        """
        Initialize model registry.

        Args:
            use_placeholder: If True, return placeholder models.
        """
        self.use_placeholder = use_placeholder
        logger.info("ModelRegistry initialized (placeholder=%s)", use_placeholder)

    def load_ensemble(
        self, version: str = "latest"
    ) -> "PlaceholderEnsemble | TrainedHuBERTEnsemble":
        """
        Load ensemble model.

        Args:
            version: Model version. Currently ignored: neither branch selects
                a model by version.

        Returns:
            A ``PlaceholderEnsemble`` when ``use_placeholder`` is True,
            otherwise a ``TrainedHuBERTEnsemble`` built with its default
            checkpoint path.
        """
        if not self.use_placeholder:
            logger.info("Loading trained HuBERT model")
            return TrainedHuBERTEnsemble()

        logger.info("Loading placeholder ensemble model")
        return PlaceholderEnsemble()

    def load_calibration(self, version: str = "latest") -> "PlaceholderCalibration":
        """
        Load calibration parameters.

        Args:
            version: Calibration version. Currently ignored.

        Returns:
            A ``PlaceholderCalibration`` in both modes; the non-placeholder
            mode logs a warning because real calibration loading is not
            implemented.
        """
        if self.use_placeholder:
            logger.info("Loading placeholder calibration")
        else:
            logger.warning("MLflow calibration loading not implemented, using placeholder")
        return PlaceholderCalibration()
class PlaceholderEnsemble:
    """Placeholder ensemble model that returns mock predictions."""

    # Keeps log() finite should a sampled probability be exactly 0 or 1.
    _EPS = 1e-8

    def __init__(self, seed: int = 42):
        """
        Initialize placeholder model.

        Args:
            seed: Seed for this instance's private random generator.
        """
        self.seed = seed
        self.version = "placeholder-v1.0"
        # Instance-local generator: yields the same draws as
        # np.random.seed(seed) followed by np.random.beta(...), but does not
        # reseed the process-wide NumPy RNG that other code may rely on.
        self._rng = np.random.RandomState(seed)

    def predict(self, waveform: np.ndarray, spectrogram: np.ndarray, acoustic_features: np.ndarray) -> dict:
        """
        Generate mock prediction.

        Args:
            waveform: Audio waveform (not used in placeholder)
            spectrogram: Spectrogram features (not used in placeholder)
            acoustic_features: Acoustic features (not used in placeholder)

        Returns:
            Dictionary with:
                "logits": array [healthy, dysarthric] of log-probabilities
                "probabilities": array [1 - p, p]
                "raw_probability": p as a Python float
        """
        # Beta(2, 5) has mean 2/7 (~0.29), biasing the mock output towards
        # the "healthy" (non-dysarthric) class.
        prob_dysarthric = float(self._rng.beta(2, 5))

        # Use log-probabilities as logits so that softmax(logits) reproduces
        # `probabilities`; downstream calibration that applies a softmax then
        # agrees with `raw_probability`.
        clipped = min(max(prob_dysarthric, self._EPS), 1.0 - self._EPS)
        logits = np.array([np.log(1.0 - clipped), np.log(clipped)])
        probs = np.array([1 - prob_dysarthric, prob_dysarthric])

        logger.debug("Placeholder prediction: prob_dysarthric=%.3f", prob_dysarthric)
        return {
            "logits": logits,
            "probabilities": probs,
            "raw_probability": float(prob_dysarthric),
        }
class TrainedHuBERTEnsemble:
    """Real trained HuBERT model for dysarthria detection."""

    # Every input is padded or truncated to MAX_DURATION_S seconds at
    # SAMPLE_RATE samples per second before being fed to the model.
    SAMPLE_RATE = 16000
    MAX_DURATION_S = 10

    def __init__(self, checkpoint_path: str = "models/hubert_fast_best.pt"):
        """
        Initialize with trained checkpoint.

        Args:
            checkpoint_path: Path to trained model checkpoint

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        # Imported here rather than at module level, so the training package
        # is only required when a trained model is actually constructed.
        from training.train_hubert_fast import SimplifiedHuBERTClassifier

        self.checkpoint_path = Path(checkpoint_path)
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(f"Model checkpoint not found: {checkpoint_path}")

        # Detect device. `torch.backends.mps` is missing on older PyTorch
        # builds, so look it up defensively instead of assuming it exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Load model
        logger.info("Loading trained HuBERT model from %s on %s", checkpoint_path, self.device)
        self.model = SimplifiedHuBERTClassifier(freeze_base=True).to(self.device)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded here.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        self.version = f"hubert-fast-epoch{checkpoint['epoch']}-auc{checkpoint['val_auc']:.4f}"
        logger.info("✓ Loaded trained model: %s", self.version)

    def predict(self, waveform: np.ndarray, spectrogram: np.ndarray, acoustic_features: np.ndarray) -> dict:
        """
        Generate prediction using trained model.

        Args:
            waveform: Audio waveform. Must be one-dimensional after removing
                singleton axes (e.g. shape ``(N,)`` or ``(1, N)``).
            spectrogram: Spectrogram features (not used by HuBERT)
            acoustic_features: Acoustic features (not used by HuBERT)

        Returns:
            Dictionary with:
                "logits": first row of the model output as a numpy array
                "probabilities": softmax of those logits
                "raw_probability": probabilities[1] as a Python float

        Raises:
            ValueError: If the waveform is not one-dimensional once singleton
                axes are dropped.
        """
        target_length = self.SAMPLE_RATE * self.MAX_DURATION_S

        # Normalise to a 1-D float32 array. Without this, a (1, N) input
        # would be measured by its first axis and padded along both axes.
        samples = np.atleast_1d(np.squeeze(np.asarray(waveform, dtype=np.float32)))
        if samples.ndim != 1:
            raise ValueError(
                f"waveform must be one-dimensional, got shape {np.shape(waveform)}"
            )

        # Pad with trailing zeros or truncate to the fixed window length.
        if samples.size > target_length:
            samples = samples[:target_length]
        else:
            samples = np.pad(samples, (0, target_length - samples.size))

        # Convert to a (1, target_length) tensor on the model's device.
        samples = np.ascontiguousarray(samples)
        waveform_tensor = torch.from_numpy(samples).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            logits = self.model(waveform_tensor)
            probs = torch.softmax(logits, dim=1)

        # Convert to numpy, dropping the batch axis.
        logits_np = logits.cpu().numpy()[0]
        probs_np = probs.cpu().numpy()[0]

        logger.debug("Trained model prediction: prob_dysarthric=%.3f", probs_np[1])
        return {
            "logits": logits_np,
            "probabilities": probs_np,
            "raw_probability": float(probs_np[1]),
        }
class PlaceholderCalibration:
    """Placeholder calibration: a plain softmax over the logits, with no learned parameters."""

    def transform(self, logits: np.ndarray) -> float:
        """
        Apply calibration to logits.

        Args:
            logits: Model logits [healthy, dysarthric]. Array-likes and
                inputs with singleton axes (e.g. shape ``(1, 2)``) are
                accepted and reduced to one dimension.

        Returns:
            Calibrated probability of dysarthria (softmax entry at index 1)

        Raises:
            ValueError: If the logits are not one-dimensional after dropping
                singleton axes, or contain fewer than two values.
        """
        values = np.atleast_1d(np.squeeze(np.asarray(logits, dtype=float)))
        if values.ndim != 1 or values.size < 2:
            raise ValueError(
                f"logits must hold at least two values along one axis, got shape {np.shape(logits)}"
            )

        # Subtracting the maximum before exponentiating avoids overflow for
        # large logits without changing the softmax result.
        exp_logits = np.exp(values - np.max(values))
        probs = exp_logits / np.sum(exp_logits)

        calibrated_prob = float(probs[1])  # Probability of dysarthric class
        logger.debug("Placeholder calibration: %.3f", calibrated_prob)
        return calibrated_prob