# Arabic_Finetuned_ASR_Nemo / testing_main_v2.py
# Uploaded by alaatiger989 via the upload-large-folder tool (commit b5e57ee, verified).
# NOTE(review): the lines above were scraped Hugging Face page text, not Python;
# they are kept here as comments so the file remains importable.
# import sounddevice as sd
# import scipy.io.wavfile as wav
# import nemo.collections.asr as nemo_asr
# import torch
# import numpy as np
# from typing import List, Tuple
# # ===== SETTINGS =====
# SAMPLE_RATE = 16000
# DURATION = 10 # seconds
# OUTPUT_FILE = "arabic_recording.wav"
# class RepetitionAwareTranscriber:
# def __init__(self, model_path: str):
# """Initialize ASR model with repetition-aware configuration"""
# print("📥 Loading Arabic ASR model...")
# self.asr_model = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
# self._configure_decoding()
# def _configure_decoding(self):
# """Configure advanced decoding strategy"""
# decoding_cfg = self.asr_model.cfg.decoding
# # Use beam search for better sequence modeling
# decoding_cfg.strategy = "beam"
# decoding_cfg.beam.beam_size = 128 # Larger beam for more candidates
# decoding_cfg.beam.return_best_hypothesis = False # Get multiple hypotheses
# # Language model parameters (if available)
# if hasattr(decoding_cfg.beam, 'beam_alpha'):
# decoding_cfg.beam.beam_alpha = 0.3 # LM weight (lower = less LM influence)
# if hasattr(decoding_cfg.beam, 'beam_beta'):
# decoding_cfg.beam.beam_beta = 0.5 # Word insertion bonus
# self.asr_model.change_decoding_strategy(decoding_cfg)
# def transcribe_with_logprobs(self, audio_file: str, temperature: float = 1.0):
# """
# Transcribe with log probabilities and temperature scaling
# Args:
# audio_file: Path to audio file
# temperature: Controls randomness (lower = more conservative, higher = more diverse)
# 0.5 = more deterministic
# 1.0 = standard
# 1.5 = more exploratory
# """
# print(f"🔍 Transcribing with temperature={temperature}...")
# # Update temperature in decoding config
# if hasattr(self.asr_model.cfg.decoding, 'temperature'):
# self.asr_model.cfg.decoding.temperature = temperature
# if hasattr(self.asr_model.cfg.decoding.beam, 'softmax_temperature'):
# self.asr_model.cfg.decoding.beam.softmax_temperature = temperature
# self.asr_model.change_decoding_strategy(self.asr_model.cfg.decoding)
# # Get multiple hypotheses with their scores
# hypotheses = self.asr_model.transcribe(
# [audio_file],
# batch_size=1,
# return_hypotheses=True,
# num_workers=0
# )
# # Handle different return types
# if isinstance(hypotheses, list) and len(hypotheses) > 0:
# hyp = hypotheses[0]
# # Check if it's a Hypothesis object or a list
# if isinstance(hyp, list):
# # It's already a list of transcriptions
# best_text = hyp[0] if len(hyp) > 0 else ""
# print(f"\n📊 Top hypothesis: {best_text}")
# return best_text
# elif hasattr(hyp, 'text'):
# # It's a Hypothesis object
# text = hyp.text
# # Check for nbest hypotheses
# if hasattr(hyp, 'nbest') and len(hyp.nbest) > 1:
# print(f"\n📊 Top {min(5, len(hyp.nbest))} hypotheses:")
# for i, nbest_hyp in enumerate(hyp.nbest[:5]):
# score = nbest_hyp.score if hasattr(nbest_hyp, 'score') else 'N/A'
# hyp_text = nbest_hyp.text if hasattr(nbest_hyp, 'text') else str(nbest_hyp)
# print(f" {i+1}. [{score}] {hyp_text}")
# return text
# else:
# # Fallback: convert to string
# return str(hyp)
# return ""
# def transcribe_with_frame_analysis(self, audio_file: str):
# """
# Analyze frame-level predictions to detect repetitions
# This examines the raw CTC outputs before collapsing
# """
# print("🔍 Performing frame-level analysis...")
# # Get log probabilities at frame level
# log_probs = self.asr_model.transcribe(
# [audio_file],
# batch_size=1,
# logprobs=True
# )
# # Standard transcription
# transcription = self.asr_model.transcribe([audio_file])
# return transcription[0], log_probs
# def transcribe_with_all_methods(self, audio_file: str):
# """Try multiple decoding strategies and return all results"""
# results = {}
# # Method 1: Standard beam search
# print("\n--- Method 1: Standard Beam Search ---")
# results['beam_standard'] = self.transcribe_with_logprobs(audio_file, temperature=1.0)
# # Method 2: Lower temperature (more conservative)
# print("\n--- Method 2: Conservative (temp=0.5) ---")
# results['beam_conservative'] = self.transcribe_with_logprobs(audio_file, temperature=0.5)
# # Method 3: Higher temperature (more exploratory)
# print("\n--- Method 3: Exploratory (temp=1.5) ---")
# results['beam_exploratory'] = self.transcribe_with_logprobs(audio_file, temperature=1.5)
# # Method 4: Frame-level analysis
# print("\n--- Method 4: Frame-level Analysis ---")
# results['frame_analysis'], _ = self.transcribe_with_frame_analysis(audio_file)
# return results
# def post_process_repetitions(text: str, audio_duration: float, expected_word_count: int = None) -> str:
# """
# Heuristic post-processing to restore repetitions
# Args:
# text: Transcribed text
# audio_duration: Duration of audio in seconds
# expected_word_count: Expected number of words (if known)
# """
# words = text.split()
# # Calculate speaking rate (words per second)
# speaking_rate = len(words) / audio_duration
# # Normal Arabic speaking rate is 2-3 words per second
# # For numbers, it's often slower (1-2 words per second)
# # If rate is too high, likely missing repetitions
# if speaking_rate > 3.0 and expected_word_count:
# print(f"⚠️ Speaking rate unusually high ({speaking_rate:.1f} w/s)")
# print(f" Expected ~{expected_word_count} words, got {len(words)}")
# print(" Possible missing repetitions detected")
# return text
# def detect_number_patterns(text: str) -> List[str]:
# """Detect if text contains Arabic number words"""
# arabic_numbers = [
# 'صفر', 'زيرو', 'واحد', 'اثنين', 'ثلاثة', 'أربعة',
# 'خمسة', 'ستة', 'سبعة', 'ثمانية', 'تسعة'
# ]
# words = text.split()
# detected = [w for w in words if w in arabic_numbers]
# if detected:
# print(f"🔢 Detected number words: {' '.join(detected)}")
# return detected
# # ===== MAIN EXECUTION =====
# if __name__ == "__main__":
# # ===== STEP 1: Record audio =====
# print("🎙️ Recording... Speak Arabic now!")
# print("💡 TIP: For repeated numbers, pause slightly between each repetition")
# print(" Example: 'زيرو [pause] زيرو [pause] واحد [pause] واحد'\n")
# audio = sd.rec(int(SAMPLE_RATE * DURATION), samplerate=SAMPLE_RATE, channels=1, dtype='int16')
# sd.wait()
# wav.write(OUTPUT_FILE, SAMPLE_RATE, audio)
# print(f"✅ Recording finished. Saved as {OUTPUT_FILE}\n")
# # ===== STEP 2: Initialize transcriber =====
# model_path = "C:/Users/thegh/Python_Projects/Expertflow/UnderProgress/Arabic_Contextual_ASR/PreparingDatasetStreamlitApp/4_Finetuning_Nemo_ASR_arabic_names_and_complaints_for_phones/output_finetuned/finetuned_model_best.nemo"
# transcriber = RepetitionAwareTranscriber(model_path)
# # ===== STEP 3: Transcribe with all methods =====
# results = transcriber.transcribe_with_all_methods(OUTPUT_FILE)
# # ===== STEP 4: Display all results =====
# print("\n" + "="*60)
# print("📝 FINAL RESULTS:")
# print("="*60)
# for method, transcription in results.items():
# print(f"\n{method.upper()}:")
# print(f" {transcription}")
# detect_number_patterns(transcription)
# # ===== STEP 5: Post-processing analysis =====
# print("\n" + "="*60)
# print("🔍 POST-PROCESSING ANALYSIS:")
# print("="*60)
# best_transcription = results['beam_standard']
# processed = post_process_repetitions(best_transcription, DURATION)
# print(f"\nBest transcription: {best_transcription}")
# print(f"Word count: {len(best_transcription.split())}")
# print(f"Speaking rate: {len(best_transcription.split()) / DURATION:.2f} words/sec")
# # ===== STEP 6: Recommendations =====
# print("\n" + "="*60)
# print("💡 RECOMMENDATIONS:")
# print("="*60)
# print("1. Compare all method outputs above")
# print("2. If all methods miss repetitions, the issue is in the trained model")
# print("3. Consider retraining with more repetitive sequences in training data")
# print("4. When speaking, add slight pauses between repeated words")
# print("5. If transcribing phone numbers, use digit-by-digit model instead")
import sounddevice as sd
import scipy.io.wavfile as wav
import nemo.collections.asr as nemo_asr
import torch
import numpy as np
from typing import List, Tuple
# ===== SETTINGS =====
SAMPLE_RATE = 16000
DURATION = 10 # seconds
OUTPUT_FILE = "arabic_recording.wav"
class RepetitionAwareTranscriber:
def __init__(self, model_path: str):
"""Initialize ASR model with repetition-aware configuration"""
print("📥 Loading Arabic ASR model...")
# Try to load as Hybrid RNNT-CTC first (better for repetitions!)
try:
self.asr_model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(model_path)
self.model_type = "hybrid_rnnt_ctc"
print("✅ Loaded as Hybrid RNNT-CTC model (excellent for repetitions!)")
except:
try:
self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)
self.model_type = "rnnt"
print("✅ Loaded as RNNT model")
except:
self.asr_model = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
self.model_type = "ctc"
print("✅ Loaded as CTC model")
self._configure_decoding()
def _configure_decoding(self):
"""Configure advanced decoding strategy"""
decoding_cfg = self.asr_model.cfg.decoding
# Use beam search for better sequence modeling
decoding_cfg.strategy = "beam"
decoding_cfg.beam.beam_size = 128 # Larger beam for more candidates
decoding_cfg.beam.return_best_hypothesis = False # Get multiple hypotheses
# Language model parameters (if available)
if hasattr(decoding_cfg.beam, 'beam_alpha'):
decoding_cfg.beam.beam_alpha = 0.3 # LM weight (lower = less LM influence)
if hasattr(decoding_cfg.beam, 'beam_beta'):
decoding_cfg.beam.beam_beta = 0.5 # Word insertion bonus
self.asr_model.change_decoding_strategy(decoding_cfg)
def transcribe_with_logprobs(self, audio_file: str, temperature: float = 1.0):
"""
Transcribe with log probabilities and temperature scaling
Args:
audio_file: Path to audio file
temperature: Controls randomness (lower = more conservative, higher = more diverse)
0.5 = more deterministic
1.0 = standard
1.5 = more exploratory
"""
print(f"🔍 Transcribing with temperature={temperature}...")
# Update temperature in decoding config
if hasattr(self.asr_model.cfg.decoding, 'temperature'):
self.asr_model.cfg.decoding.temperature = temperature
if hasattr(self.asr_model.cfg.decoding.beam, 'softmax_temperature'):
self.asr_model.cfg.decoding.beam.softmax_temperature = temperature
self.asr_model.change_decoding_strategy(self.asr_model.cfg.decoding)
# Get multiple hypotheses with their scores
hypotheses = self.asr_model.transcribe(
[audio_file],
batch_size=1,
return_hypotheses=True,
num_workers=0
)
print(hypotheses)
# Handle different return types
if isinstance(hypotheses, list) and len(hypotheses) > 0:
hyp = hypotheses[0]
# Check if it's a Hypothesis object or a list
if isinstance(hyp, list):
# It's already a list of transcriptions
best_text = hyp[0] if len(hyp) > 0 else ""
print(f"\n📊 Top hypothesis: {best_text}")
return best_text
elif hasattr(hyp, 'text'):
# It's a Hypothesis object
text = hyp.text
# Check for nbest hypotheses
if hasattr(hyp, 'nbest') and len(hyp.nbest) > 1:
print(f"\n📊 Top {min(5, len(hyp.nbest))} hypotheses:")
for i, nbest_hyp in enumerate(hyp.nbest[:5]):
score = nbest_hyp.score if hasattr(nbest_hyp, 'score') else 'N/A'
hyp_text = nbest_hyp.text if hasattr(nbest_hyp, 'text') else str(nbest_hyp)
print(f" {i+1}. [{score}] {hyp_text}")
return text
else:
# Fallback: convert to string
return str(hyp)
return ""
def transcribe_with_frame_analysis(self, audio_file: str):
"""
Analyze frame-level predictions to detect repetitions
This examines the raw CTC outputs before collapsing
"""
print("🔍 Performing frame-level analysis...")
# Get log probabilities at frame level
log_probs = self.asr_model.transcribe(
[audio_file],
batch_size=1,
logprobs=True
)
# Standard transcription
transcription = self.asr_model.transcribe([audio_file])
return transcription[0], log_probs
def transcribe_with_all_methods(self, audio_file: str):
"""Try multiple decoding strategies and return all results"""
results = {}
# Method 1: Standard beam search
print("\n--- Method 1: Standard Beam Search ---")
results['beam_standard'] = self.transcribe_with_logprobs(audio_file, temperature=1.0)
print(f"Results with Temp 1.0 : {results['beam_standard']}")
# Method 2: Lower temperature (more conservative)
print("\n--- Method 2: Conservative (temp=0.5) ---")
results['beam_conservative'] = self.transcribe_with_logprobs(audio_file, temperature=0.5)
print(f"Results with Temp 0.5 : {results['beam_conservative']}")
# Method 3: Higher temperature (more exploratory)
print("\n--- Method 3: Exploratory (temp=1.5) ---")
results['beam_exploratory'] = self.transcribe_with_logprobs(audio_file, temperature=1.5)
print(f"Results with Temp 1.5 : {results['beam_exploratory']}")
# Method 4: Frame-level analysis
# print("\n--- Method 4: Frame-level Analysis ---")
# results['frame_analysis'], _ = self.transcribe_with_frame_analysis(audio_file)
return results
def post_process_repetitions(text: str, audio_duration: float, expected_word_count: int = None) -> str:
"""
Heuristic post-processing to restore repetitions
Args:
text: Transcribed text
audio_duration: Duration of audio in seconds
expected_word_count: Expected number of words (if known)
"""
words = text.split()
# Calculate speaking rate (words per second)
speaking_rate = len(words) / audio_duration
# Normal Arabic speaking rate is 2-3 words per second
# For numbers, it's often slower (1-2 words per second)
# If rate is too high, likely missing repetitions
if speaking_rate > 3.0 and expected_word_count:
print(f"⚠️ Speaking rate unusually high ({speaking_rate:.1f} w/s)")
print(f" Expected ~{expected_word_count} words, got {len(words)}")
print(" Possible missing repetitions detected")
return text
def detect_number_patterns(text: str) -> List[str]:
"""Detect if text contains Arabic number words"""
arabic_numbers = [
'صفر', 'زيرو', 'واحد', 'اثنين', 'ثلاثة', 'أربعة',
'خمسة', 'ستة', 'سبعة', 'ثمانية', 'تسعة'
]
words = text.split()
detected = [w for w in words if w in arabic_numbers]
if detected:
print(f"🔢 Detected number words: {' '.join(detected)}")
return detected
# ===== MAIN EXECUTION =====
if __name__ == "__main__":
# ===== STEP 1: Record audio =====
print("🎙️ Recording... Speak Arabic now!")
print("💡 TIP: For repeated numbers, pause slightly between each repetition")
print(" Example: 'زيرو [pause] زيرو [pause] واحد [pause] واحد'\n")
audio = sd.rec(int(SAMPLE_RATE * DURATION), samplerate=SAMPLE_RATE, channels=1, dtype='int16')
sd.wait()
wav.write(OUTPUT_FILE, SAMPLE_RATE, audio)
print(f"✅ Recording finished. Saved as {OUTPUT_FILE}\n")
# ===== STEP 2: Initialize transcriber =====
model_path = "C:/Users/thegh/Python_Projects/Expertflow/UnderProgress/Arabic_Contextual_ASR/PreparingDatasetStreamlitApp/4_Finetuning_Nemo_ASR_arabic_names_and_complaints_for_phones/output_finetuned/finetuned_model_best.nemo"
transcriber = RepetitionAwareTranscriber(model_path)
# ===== STEP 3: Transcribe with all methods =====
results = transcriber.transcribe_with_all_methods(OUTPUT_FILE)
# ===== STEP 4: Display all results =====
print("\n" + "="*60)
print("📝 FINAL RESULTS:")
print("="*60)
for method, transcription in results.items():
print(f"\n{method.upper()}:")
print(f" {transcription}")
detect_number_patterns(transcription)
# ===== STEP 5: Post-processing analysis =====
print("\n" + "="*60)
print("🔍 POST-PROCESSING ANALYSIS:")
print("="*60)
best_transcription = results['beam_standard']
processed = post_process_repetitions(best_transcription, DURATION)
print(f"\nBest transcription: {best_transcription}")
print(f"Word count: {len(best_transcription.split())}")
print(f"Speaking rate: {len(best_transcription.split()) / DURATION:.2f} words/sec")
# ===== STEP 6: Recommendations =====
print("\n" + "="*60)
print("💡 RECOMMENDATIONS:")
print("="*60)
print("1. Compare all method outputs above")
print("2. If all methods miss repetitions, the issue is in the trained model")
print("3. Consider retraining with more repetitive sequences in training data")
print("4. When speaking, add slight pauses between repeated words")
print("5. If transcribing phone numbers, use digit-by-digit model instead")