# Source: ASR/src/eval/ScoreCalcs.py (21.7 kB)
# Last commit 88a679b by MihirRPatil:
#   deploy: CDAC ASR backend with pitch/stress fix and LLM feedback
import Levenshtein
from typing import List, Tuple, Dict
import numpy as np
from fastdtw import fastdtw
import torch
import torchaudio.functional as F
class PronunciationScorer:
    """Pronunciation scoring helpers: phonemes, duration, pitch, stress, GoP.

    The class bundles independent scorers that compare a predicted
    (learner) utterance with a reference utterance:

    - phoneme alignment / accuracy based on Levenshtein edit operations,
    - per-phoneme duration comparison,
    - pitch-contour and energy-envelope similarity via DTW,
    - CTC forced alignment and a Goodness-of-Pronunciation (GoP) score.

    Aligned phoneme pairs are always ``(pred, ref)`` tuples in which a
    missing side is represented by the gap symbol ``'-'``.
    """

    # Gap marker used in aligned (pred, ref) pairs.
    _GAP = '-'

    def __init__(self):
        # Relative weights of the sub-scores. None of the methods in this
        # class read them; presumably callers combine the sub-scores.
        self.weights = {
            'phoneme': 0.5,
            'duration': 0.2,
            'stress': 0.2,
            'pitch': 0.1
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_waveform_array(waveform):
        """Return *waveform* as a squeezed NumPy array with a float dtype.

        Accepts torch tensors (detached and moved to CPU first, so CUDA
        or grad-tracking tensors are fine), NumPy arrays and plain
        sequences. Integer PCM is promoted to float64 so that squaring
        samples cannot overflow; floating inputs keep their dtype.
        """
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()
        # atleast_1d keeps len() usable when squeeze() collapses to 0-d.
        samples = np.atleast_1d(np.asarray(waveform).squeeze())
        if not np.issubdtype(samples.dtype, np.floating):
            samples = samples.astype(np.float64)
        return samples

    @staticmethod
    def _track_f0(samples, sr, frame_size, hop_length, voicing_threshold):
        """Frame-wise autocorrelation pitch tracker (pure NumPy).

        Args:
            samples: 1-D waveform array.
            sr: sample rate in Hz.
            frame_size: analysis window length in samples.
            hop_length: step between consecutive windows in samples.
            voicing_threshold: a frame counts as voiced only if the
                autocorrelation peak exceeds this fraction of the
                zero-lag value.

        Returns:
            List with one float per frame: the pitch in Hz, or 0.0 for
            frames judged silent/unvoiced. The search is limited to
            lags corresponding to 50-500 Hz.
        """
        min_lag = int(sr / 500)
        max_lag = int(sr / 50)
        f0 = []
        for start in range(0, len(samples) - frame_size, hop_length):
            frame = samples[start:start + frame_size]
            frame = frame - np.mean(frame)
            if np.std(frame) < 1e-4:
                # Practically silent frame.
                f0.append(0.0)
                continue
            corr = np.correlate(frame, frame, mode='full')
            corr = corr[len(corr) // 2:]  # keep non-negative lags only
            pitch = 0.0
            # The frame must be longer than the largest lag searched.
            if len(corr) > max_lag:
                search_region = corr[min_lag:max_lag]
                if len(search_region) > 0:
                    peak_lag = int(np.argmax(search_region)) + min_lag
                    # peak_lag can only be 0 for sr < 500; avoid dividing by it.
                    if peak_lag > 0 and corr[peak_lag] > voicing_threshold * corr[0]:
                        pitch = float(sr / peak_lag)
            f0.append(pitch)
        return f0

    @staticmethod
    def _linear_alignment(num_frames, num_targets):
        """Split ``num_frames`` evenly into ``num_targets`` (start, end) spans.

        Used as a fallback when CTC forced alignment is not possible.
        Each span is inclusive and at least one frame long.
        """
        step = num_frames / max(num_targets, 1)
        intervals = []
        for idx in range(num_targets):
            start = int(idx * step)
            end = int((idx + 1) * step) - 1
            intervals.append((start, max(start, end)))
        return intervals

    @staticmethod
    def _path_to_intervals(path, targets_list, blank_id):
        """Convert a frame-level CTC path into per-target (start, end) spans.

        Walks the path once. A new target starts when the frame token
        equals the next target and either differs from the current
        target or was separated from it by a blank (the CTC rule for
        repeated tokens). Blank frames are not included in any span.
        """
        num_targets = len(targets_list)
        intervals = []
        target_idx = 0
        start_frame = None
        end_frame = None
        saw_blank = False
        for t, token in enumerate(path):
            if token == blank_id:
                saw_blank = True
                continue
            moves_to_next = (
                target_idx + 1 < num_targets
                and token == targets_list[target_idx + 1]
                and start_frame is not None
                and (targets_list[target_idx + 1] != targets_list[target_idx] or saw_blank)
            )
            if moves_to_next:
                intervals.append((start_frame, end_frame))
                target_idx += 1
                start_frame = t
                end_frame = t
                saw_blank = False
            elif target_idx < num_targets and token == targets_list[target_idx]:
                if start_frame is None:
                    start_frame = t
                end_frame = t
                saw_blank = False
        if start_frame is not None:
            intervals.append((start_frame, end_frame))
        return intervals

    # ------------------------------------------------------------------
    # Phoneme alignment and accuracy
    # ------------------------------------------------------------------
    def _get_alignment_ops(self, pred: List[str], ref: List[str]) -> List[Tuple[str, str]]:
        """
        Returns aligned (pred, ref) phoneme pairs with gaps marked as '-'
        using Levenshtein edit operations that transform ref into pred.
        """
        aligned = []
        i, j = 0, 0  # cursors into ref and pred respectively
        for op in Levenshtein.editops(ref, pred):
            op_name, ref_pos, pred_pos = op[0], op[1], op[2]
            # Emit the matching phonemes that precede this edit.
            while i < ref_pos and j < pred_pos:
                aligned.append((pred[j], ref[i]))
                i += 1
                j += 1
            if op_name == 'replace':
                aligned.append((pred[pred_pos], ref[ref_pos]))
                i += 1
                j += 1
            elif op_name == 'delete':
                # Reference phoneme has no counterpart in the prediction.
                aligned.append((self._GAP, ref[ref_pos]))
                i += 1
            elif op_name == 'insert':
                # Predicted phoneme has no counterpart in the reference.
                aligned.append((pred[pred_pos], self._GAP))
                j += 1
        # Emit the matching tail after the last edit.
        while i < len(ref) and j < len(pred):
            aligned.append((pred[j], ref[i]))
            i += 1
            j += 1
        return aligned

    def phoneme_accuracy(self, pred: List[str], ref: List[str]) -> Tuple[float, List[Tuple[str, str]]]:
        """
        Returns:
        - accuracy score (0-1): matching pairs / number of reference phonemes
          (0.0 when the reference is empty)
        - aligned phoneme pairs with gaps
        """
        aligned = self._get_alignment_ops(pred, ref)
        correct = sum(1 for p, r in aligned if p == r)
        total_ref = sum(1 for _, r in aligned if r != self._GAP)
        accuracy = (correct / total_ref) if total_ref > 0 else 0.0
        return accuracy, aligned

    def get_error_stats(self, aligned: List[Tuple[str, str]]) -> Dict[str, int]:
        """Returns counts of substitutions, insertions, deletions"""
        stats = {'sub': 0, 'ins': 0, 'del': 0}
        for p, r in aligned:
            if p == self._GAP and r != self._GAP:
                stats['del'] += 1
            elif p != self._GAP and r == self._GAP:
                stats['ins'] += 1
            elif p != r:
                stats['sub'] += 1
        return stats

    # ------------------------------------------------------------------
    # Duration
    # ------------------------------------------------------------------
    def duration_score(self,
                       pred_times: List[Tuple[float, float]],
                       ref_times: List[Tuple[float, float]],
                       aligned_pairs: List[Tuple[str, str]]) -> Dict[str, float]:
        """
        Calculate duration metrics for aligned phonemes.

        ``pred_times`` / ``ref_times`` hold one (start, end) pair in
        seconds per predicted / reference phoneme, in order. Only pairs
        where both sides exist and the reference duration is positive
        contribute. If a timestamp list is shorter than the alignment
        implies, comparison stops at the last available timestamp.

        Returns:
            {
                'accuracy': 0-1 score,
                'avg_ratio': average duration ratio (pred / ref),
                'error_ms': average error in milliseconds
            }
        """
        empty_result = {'accuracy': 0.0, 'avg_ratio': 1.0, 'error_ms': 0.0}
        if not pred_times or not ref_times:
            return empty_result

        scores = []
        ratios = []
        errors = []
        pred_idx, ref_idx = 0, 0
        for p_phn, r_phn in aligned_pairs:
            if p_phn != self._GAP and r_phn != self._GAP:
                # Timestamps exhausted: nothing further can be compared.
                if pred_idx >= len(pred_times) or ref_idx >= len(ref_times):
                    break
                p_start, p_end = pred_times[pred_idx]
                r_start, r_end = ref_times[ref_idx]
                p_dur = p_end - p_start
                r_dur = r_end - r_start
                if r_dur > 0:
                    ratio = p_dur / r_dur
                    ratios.append(ratio)
                    errors.append(abs(p_dur - r_dur) * 1000)
                    # Accuracy score (1 - normalized error, capped at 1)
                    scores.append(1 - min(1, abs(1 - ratio)))
                pred_idx += 1
                ref_idx += 1
            else:
                # A gap consumes a timestamp only on the side that exists.
                if p_phn == self._GAP:
                    ref_idx += 1
                if r_phn == self._GAP:
                    pred_idx += 1

        if not scores:
            return empty_result
        return {
            'accuracy': sum(scores) / len(scores),
            'avg_ratio': sum(ratios) / len(ratios),
            'error_ms': sum(errors) / len(errors)
        }

    # ------------------------------------------------------------------
    # Pitch
    # ------------------------------------------------------------------
    def _extract_pitch_contour(self, waveform, sr, phoneme_times):
        """Extract per-phoneme voiced pitch values with autocorrelation.

        Uses non-overlapping 512-sample frames. For each (start, end)
        pair in ``phoneme_times`` (seconds) the voiced (> 0 Hz) pitch
        values of the frames in that span are returned as a NumPy array.

        Returns:
            List of arrays, one per phoneme; [] if extraction fails.
        """
        hop_length = 512
        try:
            samples = self._as_waveform_array(waveform)
            f0 = np.array(self._track_f0(samples, sr, hop_length, hop_length, 0.3))
            pitch_contours = []
            for start, end in phoneme_times:
                start_idx = int(start * sr / hop_length)
                end_idx = int(end * sr / hop_length)
                segment = f0[start_idx:end_idx]
                # Drop unvoiced frames (0 values).
                pitch_contours.append(segment[segment > 0])
            return pitch_contours
        except Exception as e:
            # Best-effort: pitch is optional, so report and return nothing.
            print(f"Pitch extraction error: {e}")
            return []

    def _extract_continuous_pitch(self, waveform, sr) -> List[float]:
        """
        Extracts a continuous pitch contour using 512-sample frames with a
        320-sample hop (20 ms at 16 kHz).

        Returns:
            One entry per frame: pitch in Hz, or None for unvoiced frames
            (None serializes cleanly to JSON null). [] if extraction fails.
        """
        try:
            samples = self._as_waveform_array(waveform)
            f0 = self._track_f0(samples, sr, frame_size=512, hop_length=320,
                                voicing_threshold=0.25)
            return [hz if hz > 0.0 else None for hz in f0]
        except Exception as e:
            print(f"Error in continuous pitch extraction: {e}")
            return []

    def pitch_score(self, pred_waveform, ref_waveform, sr, aligned_pairs, pred_times, ref_times):
        """Compare pitch contours using DTW and return trajectories.

        ``aligned_pairs``, ``pred_times`` and ``ref_times`` are accepted
        for interface compatibility but are not used.

        Returns:
            {
                'similarity': 1 / (1 + length-normalized DTW distance) of the
                              z-scored voiced contours (0.8 if too few voiced
                              frames or if the computation fails),
                'error_hz': absolute difference of mean voiced pitch,
                'correlation': Pearson correlation of the two voiced contours
                               truncated to the shorter length (positional,
                               not DTW-aligned),
                'trajectory': per-frame pitch of the prediction (None = unvoiced),
                'reference_trajectory': same for the reference
            }
        """
        pred_waveform = self._as_waveform_array(pred_waveform)
        ref_waveform = self._as_waveform_array(ref_waveform)

        # Continuous trajectories, also returned for visualization.
        trajectory = self._extract_continuous_pitch(pred_waveform, sr)
        reference_trajectory = self._extract_continuous_pitch(ref_waveform, sr)

        # Keep only voiced frames for the numeric comparison.
        pred_voiced = np.array([p for p in trajectory if p is not None], dtype=np.float32)
        ref_voiced = np.array([r for r in reference_trajectory if r is not None], dtype=np.float32)

        similarity = 0.8  # default baseline
        correlation = 0.8
        error_hz = 0.0
        if len(pred_voiced) > 3 and len(ref_voiced) > 3:
            try:
                # Z-score so that contour shape is compared rather than
                # absolute register (e.g. male vs female voices).
                pred_std = np.std(pred_voiced)
                ref_std = np.std(ref_voiced)
                pred_norm = (pred_voiced - np.mean(pred_voiced)) / (pred_std if pred_std > 1e-4 else 1e-4)
                ref_norm = (ref_voiced - np.mean(ref_voiced)) / (ref_std if ref_std > 1e-4 else 1e-4)

                dtw_dist, _ = fastdtw(pred_norm, ref_norm)
                # Normalize the DTW distance by the longer sequence length.
                norm_factor = max(len(pred_norm), len(ref_norm))
                similarity = 1 / (1 + (dtw_dist / norm_factor))

                # Absolute difference in mean pitch
                error_hz = abs(np.mean(pred_voiced) - np.mean(ref_voiced))

                min_len = min(len(pred_voiced), len(ref_voiced))
                if min_len > 1:
                    corr = np.corrcoef(pred_voiced[:min_len], ref_voiced[:min_len])[0, 1]
                    # corrcoef yields NaN for a constant contour.
                    correlation = 0.0 if np.isnan(corr) else corr
            except Exception as e:
                print(f"Warning: pitch score computation failed: {e}")

        return {
            'similarity': float(similarity),
            'error_hz': float(error_hz),
            'correlation': float(correlation),
            'trajectory': trajectory,
            'reference_trajectory': reference_trajectory
        }

    # ------------------------------------------------------------------
    # Stress (energy)
    # ------------------------------------------------------------------
    def _extract_energy_envelope(self, waveform, sr) -> List[float]:
        """
        Extracts the RMS energy envelope using 512-sample frames with a
        320-sample hop (20 ms at 16 kHz). ``sr`` is not used.

        Returns:
            Per-frame RMS values scaled so the maximum is 1.0 (left
            unscaled when the signal is essentially silent); [] on failure.
        """
        hop_length = 320
        frame_size = 512
        try:
            samples = self._as_waveform_array(waveform)
            rms_values = [
                float(np.sqrt(np.mean(samples[start:start + frame_size] ** 2)))
                for start in range(0, len(samples) - frame_size, hop_length)
            ]
            energy = np.array(rms_values, dtype=np.float32)
            if len(energy) > 0:
                max_val = np.max(energy)
                if max_val > 1e-6:
                    energy = energy / max_val
            return energy.tolist()
        except Exception as e:
            print(f"Error extracting energy envelope: {e}")
            return []

    def stress_score(self, pred_waveform, ref_waveform, sr) -> Dict[str, float]:
        """
        Compare dynamic stress (energy/loudness intensity envelopes) between
        prediction and reference.

        Returns:
            {
                'accuracy': 1 / (1 + length-normalized DTW distance) of the
                            energy envelopes; 0.8 when either envelope has
                            5 or fewer frames or when computation fails,
                'error_stats': stress error counters (always zero here)
            }
        """
        def _no_stress_errors():
            # Fresh dict per call so callers cannot mutate shared state.
            return {'missing_stress': 0, 'extra_stress': 0, 'wrong_stress': 0}

        try:
            pred_energy = self._extract_energy_envelope(pred_waveform, sr)
            ref_energy = self._extract_energy_envelope(ref_waveform, sr)

            similarity = 0.8  # default baseline fallback
            if len(pred_energy) > 5 and len(ref_energy) > 5:
                dtw_dist, _ = fastdtw(np.array(pred_energy), np.array(ref_energy))
                norm_factor = max(len(pred_energy), len(ref_energy))
                similarity = 1 / (1 + (dtw_dist / norm_factor))
            return {'accuracy': float(similarity), 'error_stats': _no_stress_errors()}
        except Exception as e:
            print(f"Error computing stress score: {e}")
            return {'accuracy': 0.8, 'error_stats': _no_stress_errors()}

    # ------------------------------------------------------------------
    # Combined interface
    # ------------------------------------------------------------------
    def compute_scores(self,
                       pred_phonemes: List[str],
                       ref_phonemes: List[str],
                       pred_times: List[Tuple[float, float]] = None,
                       ref_times: List[Tuple[float, float]] = None,
                       pred_waveform = None,
                       ref_waveform = None,
                       sr: int = None) -> Dict:
        """Enhanced scoring interface with pitch and stress analysis.

        Always returns 'phoneme', 'error_stats' and 'aligned_pairs'.
        'duration' is added when both timestamp lists are non-empty;
        'pitch' and 'stress' are added when both waveforms and ``sr``
        are provided.
        """
        accuracy, aligned = self.phoneme_accuracy(pred_phonemes, ref_phonemes)
        results = {
            'phoneme': accuracy,
            'error_stats': self.get_error_stats(aligned),
            'aligned_pairs': aligned
        }
        if pred_times and ref_times:
            results['duration'] = self.duration_score(pred_times, ref_times, aligned)
        if all(x is not None for x in [pred_waveform, ref_waveform, sr]):
            results['pitch'] = self.pitch_score(
                pred_waveform, ref_waveform, sr, aligned, pred_times, ref_times)
            results['stress'] = self.stress_score(
                pred_waveform, ref_waveform, sr)
        return results

    # ------------------------------------------------------------------
    # CTC forced alignment and GoP
    # ------------------------------------------------------------------
    def ctc_forced_align(self, log_probs: torch.Tensor, targets: torch.Tensor, blank_id: int = 0) -> List[Tuple[int, int]]:
        """
        Computes CTC forced alignment for batch_size=1.

        Args:
            log_probs: Tensor of shape (1, Time, Vocab)
            targets: Tensor of shape (1, Target_Len)
            blank_id: Index of blank token
        Returns:
            List of inclusive (start_frame, end_frame), one per token in
            targets. Falls back to an even (linear) split of the frames
            when the alignment constraints are violated or torchaudio's
            forced_align raises.
        """
        _, T, _ = log_probs.shape
        L = targets.shape[1]

        # Move inputs to CPU to avoid CUDA kernel/driver binary compatibility segfaults
        # and multi-GPU device mapping issues in torchaudio's C++ extension.
        log_probs_cpu = log_probs.detach().cpu()
        targets_cpu = targets.cpu()
        targets_list = targets_cpu[0].numpy().tolist()

        # Validate constraints to prevent C++ out-of-bounds/assertion crashes:
        # 1. Target sequence cannot be empty
        # 2. Input frames must be >= target length
        # 3. Target sequence must not contain the blank/pad token
        if L == 0 or T < L or blank_id in targets_list:
            print(f"Warning: CTC alignment constraints violated (T={T}, L={L}, blank_in_target={blank_id in targets_list}). Falling back to linear alignment.")
            return self._linear_alignment(T, L)

        input_lengths = torch.tensor([T], dtype=torch.long, device="cpu")
        target_lengths = torch.tensor([L], dtype=torch.long, device="cpu")
        # Log softmax along vocab dimension
        log_probs_norm = torch.log_softmax(log_probs_cpu, dim=-1)

        try:
            # torchaudio forced_align on CPU
            alignments, _ = F.forced_align(
                log_probs_norm,
                targets_cpu,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                blank=blank_id
            )
            path = alignments[0].numpy().tolist()
            intervals = self._path_to_intervals(path, targets_list, blank_id)
        except Exception as e:
            print(f"Warning: torchaudio forced_align failed: {e}. Falling back to linear alignment.")
            return self._linear_alignment(T, L)

        # Pad so that exactly one interval per target is returned, reusing
        # the last known span (or the whole utterance if none was found).
        while len(intervals) < L:
            intervals.append(intervals[-1] if intervals else (0, T - 1))
        return intervals[:L]

    def compute_gop(self,
                    log_probs: torch.Tensor,
                    targets: torch.Tensor,
                    intervals: List[Tuple[int, int]],
                    vocab_tokens: List[str],
                    blank_id: int = 0,
                    threshold: float = 0.40,
                    frame_stride_ms: float = 20.0) -> List[Dict]:
        """
        Computes Goodness of Pronunciation (GoP) using max-pooling and blank-exclusion.

        Args:
            log_probs: Tensor of shape (1, Time, Vocab); softmax is applied
                over the vocab axis, so logits and log-probabilities both work.
            targets: Tensor of shape (1, Target_Len) with target token ids.
            intervals: inclusive (start_frame, end_frame) per target; must
                contain at least Target_Len entries. Frames outside
                [0, Time - 1] are ignored.
            vocab_tokens: display label for each target, indexed by target
                position; str(token_id) is used when the list is too short.
            blank_id: index of the blank token.
            threshold: minimum GoP probability for ``is_correct``.
            frame_stride_ms: duration of one frame in milliseconds.
        Returns:
            One dict per target with keys 'phoneme', 'start_ms', 'end_ms',
            'gop_prob' and 'is_correct'.
        """
        frame_logits = log_probs[0].detach()
        # Argmax predictions across all frames to identify blank frames
        pred_ids = torch.argmax(frame_logits, dim=-1).cpu().numpy()
        probs = torch.softmax(frame_logits, dim=-1)
        num_frames = probs.shape[0]

        targets_list = targets[0].cpu().numpy().tolist()
        results = []
        for idx, token_id in enumerate(targets_list):
            phoneme = vocab_tokens[idx] if idx < len(vocab_tokens) else str(token_id)
            s_frame, e_frame = intervals[idx]

            # Clamp to the available frames; empty when nothing overlaps.
            first = max(int(s_frame), 0)
            last = min(int(e_frame), num_frames - 1)
            frames = np.arange(first, last + 1)

            # Blank-Exclusion: drop frames whose argmax prediction is blank.
            valid_frames = frames[pred_ids[frames] != blank_id]
            # If every frame in the segment is blank, evaluate all of them.
            if valid_frames.size == 0:
                valid_frames = frames

            if valid_frames.size > 0:
                frame_index = torch.as_tensor(valid_frames, dtype=torch.long, device=probs.device)
                # Max-Pooling: best probability of the target inside the segment.
                gop_prob = float(torch.max(probs[frame_index, token_id]).item())
            else:
                gop_prob = 1e-8

            results.append({
                "phoneme": phoneme,
                "start_ms": float(s_frame * frame_stride_ms),
                "end_ms": float((e_frame + 1) * frame_stride_ms),
                "gop_prob": gop_prob,
                "is_correct": bool(gop_prob >= threshold)
            })
        return results