"""
Feature extraction for IndicWav2Vec Hindi ASR.

This module provides feature extraction capabilities using the IndicWav2Vec Hindi model.
It focuses on ASR transcription features rather than hybrid acoustic+linguistic features.
"""
import logging
from typing import Dict, Any, Tuple, Optional

import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
|
# Module-level logger, named after this module, used by ASRFeatureExtractor
# for its initialization, warning and error messages.
logger = logging.getLogger(__name__)
|
|
|
|
class ASRFeatureExtractor:
    """
    Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.

    This extractor focuses on:
    - Audio feature extraction via IndicWav2Vec
    - Transcription confidence scores
    - Frame-level predictions and logits
    - Word-level alignments (estimated)

    Model: ai4bharat/indicwav2vec-hindi

    Note:
        Only the processed inputs are moved to ``device``; this class never
        moves the model itself, so the caller is expected to have placed the
        model on the same device beforehand.
    """

    # Confidence reported for a word whose estimated frame span is empty
    # (happens when there are more words than output frames).
    _EMPTY_SPAN_CONFIDENCE = 0.5

    # Literal marker strings removed from the decoded text, in this order.
    _SPECIAL_TOKENS = ('<pad>', '<s>', '</s>')

    def __init__(self, model: "Wav2Vec2ForCTC", processor: "AutoProcessor", device: str = "cpu"):
        """
        Initialize the ASR feature extractor.

        Puts the model into evaluation mode and logs the chosen device.

        Args:
            model: Pre-loaded IndicWav2Vec Hindi model
            processor: Pre-loaded processor for the model
            device: Device to run inference on ('cpu' or 'cuda')
        """
        self.model = model
        self.processor = processor
        self.device = device
        self.model.eval()
        logger.info("✅ ASRFeatureExtractor initialized on %s", device)

    def _prepare_inputs(self, audio: np.ndarray, sample_rate: int):
        """Run the processor on ``audio`` and move the resulting batch to ``self.device``."""
        return self.processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).to(self.device)

    def _decode_transcript(self, predicted_ids) -> str:
        """Best-effort decode of predicted token ids into cleaned text ('' on failure)."""
        try:
            if hasattr(self.processor, 'tokenizer'):
                transcript = self.processor.tokenizer.decode(
                    predicted_ids[0],
                    skip_special_tokens=True
                )
            elif hasattr(self.processor, 'batch_decode'):
                transcript = self.processor.batch_decode(predicted_ids)[0]
            else:
                transcript = ""

            if not transcript:
                return ""

            # Drop leftover marker strings, turn the '|' word delimiter into a
            # space, then collapse all whitespace runs (this also trims ends).
            for token in self._SPECIAL_TOKENS:
                transcript = transcript.replace(token, '')
            transcript = transcript.replace('|', ' ')
            return ' '.join(transcript.split())
        except Exception as e:
            # Decoding is deliberately best-effort: the numeric features are
            # still useful without text, so degrade to an empty transcript.
            logger.warning("⚠️ Decode error: %s", e)
            return ""

    def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Extract features from audio using IndicWav2Vec Hindi.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - input_values: Processed audio features
            - attention_mask: Attention mask, or None when the processor
              did not produce one

        Raises:
            Exception: Any error raised by the processor is logged and re-raised.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)
            return {
                'input_values': inputs.input_values,
                'attention_mask': inputs.get('attention_mask', None)
            }
        except Exception as e:
            logger.error("❌ Error extracting audio features: %s", e)
            raise

    def get_transcription_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get transcription features including logits, predictions, and confidence.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - transcript: Transcribed text ('' if decoding failed)
            - logits: Model logits (raw predictions) as a numpy array
            - predicted_ids: Predicted token IDs (argmax over the last axis)
            - probabilities: Softmax probabilities over the last axis
            - confidence: Mean of the per-frame maximum probability
            - frame_confidence: Per-frame maximum probability of the first
              batch item
            - num_frames: Size of the logits' second dimension

        Raises:
            Exception: Any processing or inference error is logged and
                re-raised. Decode errors are not raised; they yield an empty
                transcript instead.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)

            with torch.no_grad():
                logits = self.model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                probs = torch.softmax(logits, dim=-1)
                max_probs = torch.max(probs, dim=-1)[0]

            frame_confidence = max_probs[0].cpu().numpy()
            avg_confidence = float(torch.mean(max_probs).item())
            transcript = self._decode_transcript(predicted_ids)

            return {
                'transcript': transcript,
                'logits': logits.cpu().numpy(),
                'predicted_ids': predicted_ids.cpu().numpy(),
                'probabilities': probs.cpu().numpy(),
                'confidence': avg_confidence,
                'frame_confidence': frame_confidence,
                'num_frames': logits.shape[1]
            }
        except Exception as e:
            logger.error("❌ Error getting transcription features: %s", e)
            raise

    def get_word_level_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get word-level features including timestamps and confidence.

        Timestamps are estimates, not alignments: the audio duration is split
        evenly across the words of the transcript, and each word's confidence
        is the mean frame confidence over the matching share of frames.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - words: List of words
            - word_timestamps: List of dicts with keys 'word', 'start' and
              'end' (times in seconds) for each word
            - word_confidence: Confidence score for each word (0.5 when a
              word's frame span is empty)
            - transcript: The full transcript the words were split from

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            features = self.get_transcription_features(audio, sample_rate)
            transcript = features['transcript']
            frame_confidence = features['frame_confidence']
            num_frames = features['num_frames']

            words = transcript.split() if transcript else []
            num_words = len(words)
            audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / num_words if num_words else 0

            word_timestamps = []
            word_confidence = []

            for i, word in enumerate(words):
                # Frame boundaries use exact integer arithmetic. Going through
                # float seconds (i * time_per_word / audio_duration * num_frames)
                # can truncate one frame low due to rounding and divides by
                # zero for zero-length audio.
                start_frame = i * num_frames // num_words
                end_frame = (i + 1) * num_frames // num_words

                if end_frame > start_frame:
                    word_conf = float(np.mean(frame_confidence[start_frame:end_frame]))
                else:
                    word_conf = self._EMPTY_SPAN_CONFIDENCE

                word_timestamps.append({
                    'word': word,
                    'start': i * time_per_word,
                    'end': (i + 1) * time_per_word
                })
                word_confidence.append(word_conf)

            return {
                'words': words,
                'word_timestamps': word_timestamps,
                'word_confidence': word_confidence,
                'transcript': transcript
            }
        except Exception as e:
            logger.error("❌ Error getting word-level features: %s", e)
            raise
|
|
|
|