import sounddevice as sd
import scipy.io.wavfile as wav
import nemo.collections.asr as nemo_asr
import numpy as np
from typing import List

# ===== SETTINGS =====
SAMPLE_RATE = 16000
DURATION = 10  # seconds
OUTPUT_FILE = "arabic_recording.wav"

class RepetitionAwareTranscriber:
    def __init__(self, model_path: str):
        """Initialize the ASR model with a repetition-aware decoding configuration."""
        print("📥 Loading Arabic ASR model...")
        # Try the Hybrid RNNT-CTC class first (transducer decoders handle
        # repeated tokens better than plain CTC), then fall back to RNNT,
        # then to CTC.
        try:
            self.asr_model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(model_path)
            self.model_type = "hybrid_rnnt_ctc"
            print("✅ Loaded as Hybrid RNNT-CTC model (good for repetitions)")
        except Exception:
            try:
                self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)
                self.model_type = "rnnt"
                print("✅ Loaded as RNNT model")
            except Exception:
                self.asr_model = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
                self.model_type = "ctc"
                print("✅ Loaded as CTC model")
        self._configure_decoding()
    def _configure_decoding(self):
        """Configure an advanced decoding strategy."""
        decoding_cfg = self.asr_model.cfg.decoding
        # Use beam search for better sequence modeling
        decoding_cfg.strategy = "beam"
        decoding_cfg.beam.beam_size = 128  # larger beam keeps more candidates alive
        decoding_cfg.beam.return_best_hypothesis = False  # return multiple hypotheses
        # Language-model parameters (only present when an external LM is configured)
        if hasattr(decoding_cfg.beam, 'beam_alpha'):
            decoding_cfg.beam.beam_alpha = 0.3  # LM weight (lower = less LM influence)
        if hasattr(decoding_cfg.beam, 'beam_beta'):
            decoding_cfg.beam.beam_beta = 0.5  # word-insertion bonus
        self.asr_model.change_decoding_strategy(decoding_cfg)
    def transcribe_with_logprobs(self, audio_file: str, temperature: float = 1.0):
        """
        Transcribe with multiple hypotheses and temperature scaling.
        Args:
            audio_file: Path to the audio file
            temperature: Softmax temperature (lower = more conservative,
                higher = more diverse):
                0.5 = more deterministic
                1.0 = standard
                1.5 = more exploratory
        """
        print(f"🔍 Transcribing with temperature={temperature}...")
        # Update the temperature wherever this NeMo version exposes it
        if hasattr(self.asr_model.cfg.decoding, 'temperature'):
            self.asr_model.cfg.decoding.temperature = temperature
        if hasattr(self.asr_model.cfg.decoding.beam, 'softmax_temperature'):
            self.asr_model.cfg.decoding.beam.softmax_temperature = temperature
        self.asr_model.change_decoding_strategy(self.asr_model.cfg.decoding)
        # Get multiple hypotheses with their scores
        hypotheses = self.asr_model.transcribe(
            [audio_file],
            batch_size=1,
            return_hypotheses=True,
            num_workers=0
        )
        print(hypotheses)  # debug: inspect the raw return value before unpacking
        # Handle the different return types across NeMo versions
        if isinstance(hypotheses, list) and len(hypotheses) > 0:
            hyp = hypotheses[0]
            if isinstance(hyp, list):
                # Already a list of transcriptions
                best_text = hyp[0] if len(hyp) > 0 else ""
                print(f"\n📊 Top hypothesis: {best_text}")
                return best_text
            elif hasattr(hyp, 'text'):
                # A Hypothesis object
                text = hyp.text
                # Check for n-best hypotheses
                if hasattr(hyp, 'nbest') and len(hyp.nbest) > 1:
                    print(f"\n📊 Top {min(5, len(hyp.nbest))} hypotheses:")
                    for i, nbest_hyp in enumerate(hyp.nbest[:5]):
                        score = nbest_hyp.score if hasattr(nbest_hyp, 'score') else 'N/A'
                        hyp_text = nbest_hyp.text if hasattr(nbest_hyp, 'text') else str(nbest_hyp)
                        print(f"  {i+1}. [{score}] {hyp_text}")
                return text
            else:
                # Fallback: convert to string
                return str(hyp)
        return ""
    def transcribe_with_frame_analysis(self, audio_file: str):
        """
        Analyze frame-level predictions to detect repetitions.
        This examines the raw CTC outputs before the collapse step.
        """
        print("🔍 Performing frame-level analysis...")
        # Get per-frame log probabilities. Note: the logprobs=True kwarg is a
        # CTC-only transcribe() option (NeMo 1.x); RNNT models may not accept it.
        log_probs = self.asr_model.transcribe(
            [audio_file],
            batch_size=1,
            logprobs=True
        )
        # Standard transcription for comparison
        transcription = self.asr_model.transcribe([audio_file])
        return transcription[0], log_probs
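
    # Why CTC swallows back-to-back repeats, as a greedy-decoding sketch.
    # Assumptions (hypothetical, for illustration): log_probs is a (T, V)
    # numpy array of per-frame log probabilities, and the blank token is the
    # last vocabulary index (the usual NeMo CTC convention). Adjacent
    # identical tokens merge unless a blank frame separates them, which is
    # why a pause between repeated words helps them survive decoding.
    @staticmethod
    def _greedy_ctc_collapse(log_probs: np.ndarray, blank_id: int = None) -> list:
        frame_ids = log_probs.argmax(axis=-1)  # best token per frame
        if blank_id is None:
            blank_id = log_probs.shape[-1] - 1  # blank assumed to be the last index
        collapsed, prev = [], None
        for tok in frame_ids:
            if tok != prev and tok != blank_id:  # merge repeats, drop blanks
                collapsed.append(int(tok))
            prev = tok
        return collapsed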
    def transcribe_with_all_methods(self, audio_file: str):
        """Try multiple decoding strategies and return all results."""
        results = {}
        # Method 1: Standard beam search
        print("\n--- Method 1: Standard Beam Search ---")
        results['beam_standard'] = self.transcribe_with_logprobs(audio_file, temperature=1.0)
        print(f"Results with temp 1.0: {results['beam_standard']}")
        # Method 2: Lower temperature (more conservative)
        print("\n--- Method 2: Conservative (temp=0.5) ---")
        results['beam_conservative'] = self.transcribe_with_logprobs(audio_file, temperature=0.5)
        print(f"Results with temp 0.5: {results['beam_conservative']}")
        # Method 3: Higher temperature (more exploratory)
        print("\n--- Method 3: Exploratory (temp=1.5) ---")
        results['beam_exploratory'] = self.transcribe_with_logprobs(audio_file, temperature=1.5)
        print(f"Results with temp 1.5: {results['beam_exploratory']}")
        # Method 4: Frame-level analysis (disabled; CTC-only, see note above)
        # print("\n--- Method 4: Frame-level Analysis ---")
        # results['frame_analysis'], _ = self.transcribe_with_frame_analysis(audio_file)
        return results

def post_process_repetitions(text: str, audio_duration: float, expected_word_count: int = None) -> str:
    """
    Heuristic check that flags possibly missing repetitions.
    The text itself is returned unchanged.
    Args:
        text: Transcribed text
        audio_duration: Duration of the audio in seconds
        expected_word_count: Expected number of words, if known
    """
    words = text.split()
    # Speaking rate in words per second. Conversational Arabic runs at
    # roughly 2-3 words/sec; dictated numbers are usually slower (1-2 w/s).
    speaking_rate = len(words) / audio_duration
    # If the decoder collapsed repeated words, the transcript holds fewer
    # words than were spoken, so the count falls short of the expected total
    # and the measured rate looks unusually low.
    if expected_word_count and len(words) < expected_word_count:
        print(f"⚠️ Expected ~{expected_word_count} words, got {len(words)}")
        print(f"   Measured speaking rate: {speaking_rate:.1f} w/s")
        print("   Possible missing repetitions detected")
    return text
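
# Hypothetical companion helper (not in the original pipeline): count
# adjacent duplicated words. For dictated strings like
# 'زيرو زيرو واحد واحد', a count of zero where repetitions were expected
# is another hint that the decoder collapsed them.
def count_adjacent_repetitions(text: str) -> int:
    words = text.split()
    return sum(1 for a, b in zip(words, words[1:]) if a == b)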

def detect_number_patterns(text: str) -> List[str]:
    """Detect whether the text contains Arabic number words (0-9)."""
    arabic_numbers = [
        'صفر', 'زيرو', 'واحد', 'اثنين', 'ثلاثة', 'أربعة',
        'خمسة', 'ستة', 'سبعة', 'ثمانية', 'تسعة'
    ]
    words = text.split()
    detected = [w for w in words if w in arabic_numbers]
    if detected:
        print(f"🔢 Detected number words: {' '.join(detected)}")
    return detected
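
# Hypothetical extension: map detected Arabic number words to digits, e.g.
# for the phone-number use case in the recommendations below. The mapping
# covers the same words as detect_number_patterns, including the borrowed
# 'زيرو' ("zero").
ARABIC_WORD_TO_DIGIT = {
    'صفر': '0', 'زيرو': '0', 'واحد': '1', 'اثنين': '2', 'ثلاثة': '3',
    'أربعة': '4', 'خمسة': '5', 'ستة': '6', 'سبعة': '7', 'ثمانية': '8',
    'تسعة': '9',
}

def number_words_to_digits(words: List[str]) -> str:
    """Join detected number words into a digit string, e.g. '0011'."""
    return ''.join(ARABIC_WORD_TO_DIGIT.get(w, '?') for w in words)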

# ===== MAIN EXECUTION =====
if __name__ == "__main__":
    # ===== STEP 1: Record audio =====
    print("🎙️ Recording... Speak Arabic now!")
    print("💡 TIP: For repeated numbers, pause slightly between each repetition")
    print("   Example: 'زيرو [pause] زيرو [pause] واحد [pause] واحد'\n")
    audio = sd.rec(int(SAMPLE_RATE * DURATION), samplerate=SAMPLE_RATE, channels=1, dtype='int16')
    sd.wait()
    wav.write(OUTPUT_FILE, SAMPLE_RATE, audio)
    print(f"✅ Recording finished. Saved as {OUTPUT_FILE}\n")

    # ===== STEP 2: Initialize the transcriber =====
    # Path to the fine-tuned .nemo checkpoint; adjust for your machine.
    model_path = "C:/Users/thegh/Python_Projects/Expertflow/UnderProgress/Arabic_Contextual_ASR/PreparingDatasetStreamlitApp/4_Finetuning_Nemo_ASR_arabic_names_and_complaints_for_phones/output_finetuned/finetuned_model_best.nemo"
    transcriber = RepetitionAwareTranscriber(model_path)

    # ===== STEP 3: Transcribe with all methods =====
    results = transcriber.transcribe_with_all_methods(OUTPUT_FILE)

    # ===== STEP 4: Display all results =====
    print("\n" + "=" * 60)
    print("📝 FINAL RESULTS:")
    print("=" * 60)
    for method, transcription in results.items():
        print(f"\n{method.upper()}:")
        print(f"  {transcription}")
        detect_number_patterns(transcription)

    # ===== STEP 5: Post-processing analysis =====
    print("\n" + "=" * 60)
    print("🔍 POST-PROCESSING ANALYSIS:")
    print("=" * 60)
    best_transcription = results['beam_standard']
    post_process_repetitions(best_transcription, DURATION)
    print(f"\nBest transcription: {best_transcription}")
    print(f"Word count: {len(best_transcription.split())}")
    print(f"Speaking rate: {len(best_transcription.split()) / DURATION:.2f} words/sec")

    # ===== STEP 6: Recommendations =====
    print("\n" + "=" * 60)
    print("💡 RECOMMENDATIONS:")
    print("=" * 60)
    print("1. Compare all method outputs above")
    print("2. If every method misses repetitions, the issue lies in the trained model")
    print("3. Consider retraining with more repetitive sequences in the training data")
    print("4. When speaking, add slight pauses between repeated words")
    print("5. For phone numbers, consider a digit-by-digit model instead")