Spaces:
Runtime error
Runtime error
| # diagnosis/ai_engine/features.py | |
| """ | |
| Feature extraction for IndicWav2Vec Hindi ASR | |
| This module provides feature extraction capabilities using the IndicWav2Vec Hindi model. | |
| Focused on ASR transcription features rather than hybrid acoustic+linguistic features. | |
| """ | |
| import torch | |
| import numpy as np | |
| import logging | |
| from typing import Dict, Any, Tuple, Optional | |
| from transformers import Wav2Vec2ForCTC, AutoProcessor | |
| logger = logging.getLogger(__name__) | |
class ASRFeatureExtractor:
    """
    Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.

    This extractor focuses on:
    - Audio feature extraction via the model's processor
    - Transcription confidence scores (max softmax probability per frame)
    - Frame-level predictions and logits
    - Word-level alignments (estimated by splitting the clip evenly between
      words; no forced alignment is performed)

    Model: ai4bharat/indicwav2vec-hindi
    """

    def __init__(self, model: "Wav2Vec2ForCTC", processor: "AutoProcessor", device: str = "cpu"):
        """
        Initialize the ASR feature extractor.

        Args:
            model: Pre-loaded IndicWav2Vec Hindi model. It is switched to eval
                mode here but is NOT moved to ``device``.
                NOTE(review): presumably the caller has already placed the
                model on ``device`` - confirm against the loader.
            processor: Pre-loaded processor for the model
            device: Device the processed inputs are moved to ('cpu' or 'cuda')
        """
        self.model = model
        self.processor = processor
        self.device = device
        self.model.eval()
        logger.info(f"✅ ASRFeatureExtractor initialized on {device}")

    def _prepare_inputs(self, audio: np.ndarray, sample_rate: int):
        """Run the processor on ``audio`` and move the resulting batch to ``self.device``."""
        return self.processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(self.device)

    def _decode_transcript(self, predicted_ids) -> str:
        """
        Decode the first sequence of ``predicted_ids`` into cleaned text.

        Decoding is best-effort: any failure is logged as a warning and an
        empty string is returned so that the numeric features remain usable.
        """
        try:
            transcript = ""
            if hasattr(self.processor, 'tokenizer'):
                transcript = self.processor.tokenizer.decode(
                    predicted_ids[0],
                    skip_special_tokens=True,
                )
            elif hasattr(self.processor, 'batch_decode'):
                transcript = self.processor.batch_decode(predicted_ids)[0]

            if transcript:
                transcript = transcript.strip()
                # Drop any special tokens that survived decoding.
                for token in ('<pad>', '<s>', '</s>'):
                    transcript = transcript.replace(token, '')
                # '|' is treated as the word delimiter and mapped to a space.
                transcript = transcript.replace('|', ' ').strip()
                # Collapse runs of whitespace into single spaces.
                transcript = ' '.join(transcript.split())
            return transcript
        except Exception as e:
            logger.warning(f"⚠️ Decode error: {e}")
            return ""

    def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Extract model-ready input features from audio.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - input_values: Processed audio features
            - attention_mask: Attention mask, or None if the processor did
              not produce one

        Raises:
            Exception: Any processor error is logged and re-raised unchanged.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)
            return {
                'input_values': inputs.input_values,
                'attention_mask': inputs.get('attention_mask', None),
            }
        except Exception as e:
            logger.error(f"❌ Error extracting audio features: {e}")
            raise

    def get_transcription_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get transcription features including logits, predictions, and confidence.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - transcript: Transcribed text ("" if decoding failed)
            - logits: Model logits (raw predictions) as a numpy array
            - predicted_ids: Argmax token IDs as a numpy array
            - probabilities: Softmax probabilities as a numpy array
            - confidence: Mean of the per-frame maximum probabilities
            - frame_confidence: Per-frame maximum probability for the first
              item in the batch
            - num_frames: Size of the logits' second dimension

        Raises:
            Exception: Processor/model errors are logged and re-raised unchanged.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)

            # Inference only: no autograd bookkeeping is needed.
            with torch.no_grad():
                logits = self.model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                probs = torch.softmax(logits, dim=-1)
                # Confidence of a frame = probability of its most likely token.
                max_probs = probs.max(dim=-1).values

            frame_confidence = max_probs[0].cpu().numpy()
            avg_confidence = float(max_probs.mean().item())
            transcript = self._decode_transcript(predicted_ids)

            return {
                'transcript': transcript,
                'logits': logits.cpu().numpy(),
                'predicted_ids': predicted_ids.cpu().numpy(),
                'probabilities': probs.cpu().numpy(),
                'confidence': avg_confidence,
                'frame_confidence': frame_confidence,
                'num_frames': logits.shape[1],
            }
        except Exception as e:
            logger.error(f"❌ Error getting transcription features: {e}")
            raise

    def get_word_level_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get word-level features including estimated timestamps and confidence.

        Timestamps are a simplification: the clip duration is divided evenly
        between the words of the transcript. Each word's confidence is the
        mean frame confidence over its equal share of the frames.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - words: List of words
            - word_timestamps: List of dicts with 'word', 'start' and 'end'
              (times in seconds) for each word
            - word_confidence: Confidence score for each word (0.5 when a
              word is assigned no frames)
            - transcript: The full transcript the words were split from

        Raises:
            Exception: Any underlying error is logged and re-raised unchanged.
        """
        try:
            features = self.get_transcription_features(audio, sample_rate)
            transcript = features['transcript']
            frame_confidence = features['frame_confidence']
            num_frames = features['num_frames']

            words = transcript.split() if transcript else []
            n_words = len(words)
            audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / n_words if n_words else 0

            word_timestamps = []
            word_confidence = []
            for i, word in enumerate(words):
                # Frame boundaries are derived from the word index with integer
                # arithmetic. This equals the time-proportional mapping but
                # avoids dividing by the duration (which may be zero) and
                # avoids float truncation at the boundaries.
                start_frame = (i * num_frames) // n_words
                end_frame = ((i + 1) * num_frames) // n_words
                if end_frame > start_frame:
                    conf = float(np.mean(frame_confidence[start_frame:end_frame]))
                else:
                    # More words than frames: no evidence, use a neutral score.
                    conf = 0.5

                word_timestamps.append({
                    'word': word,
                    'start': i * time_per_word,
                    'end': (i + 1) * time_per_word,
                })
                word_confidence.append(conf)

            return {
                'words': words,
                'word_timestamps': word_timestamps,
                'word_confidence': word_confidence,
                'transcript': transcript,
            }
        except Exception as e:
            logger.error(f"❌ Error getting word-level features: {e}")
            raise