Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import json | |
| import argparse | |
| import soundfile as sf | |
| import io | |
| import numpy as np | |
| import traceback | |
| import re | |
| import sys | |
| from datasets import load_dataset, Audio | |
| from transformers import Wav2Vec2Processor | |
| from src.models.phoneme_embedder import Wav2Vec2PhonemeEmbedder | |
| from src.g2p.g2p_utils import G2PManager | |
| from src.utils.audio_utils import AudioPreprocessor | |
| from collections import OrderedDict | |
| def calculate_per(reference, hypothesis): | |
| """Memory-efficient Levenshtein distance for PER.""" | |
| nr = len(reference) | |
| nh = len(hypothesis) | |
| if nr == 0: return nh | |
| if nh == 0: return nr | |
| row = np.arange(nh + 1) | |
| for i in range(1, nr + 1): | |
| prev_row = row.copy() | |
| row[0] = i | |
| for j in range(1, nh + 1): | |
| cost = 0 if reference[i-1] == hypothesis[j-1] else 1 | |
| row[j] = min(prev_row[j] + 1, # deletion | |
| row[j-1] + 1, # insertion | |
| prev_row[j-1] + cost) # substitution | |
| return row[nh] / nr | |
| def main(): | |
| # Ensure NLTK resources are available for G2P | |
| import nltk | |
| print("Checking NLTK resources...", flush=True) | |
| for res in ['averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng', 'cmudict', 'punkt', 'punkt_tab']: | |
| try: | |
| nltk.download(res, quiet=True) | |
| except Exception: | |
| pass | |
| parser = argparse.ArgumentParser(description="Evaluate Phoneme Embedder on NPTEL dataset") | |
| parser.add_argument("--model_dir", default="trained_models/20k_steps", help="Path to model directory") | |
| parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to evaluate") | |
| parser.add_argument("--split", default="train", help="Dataset split") | |
| parser.add_argument("--skip", type=int, default=50000, help="Skip first N samples") | |
| parser.add_argument("--sanity_check", action="store_true", help="Run on training data (skip=0) to verify weights") | |
| args = parser.parse_args() | |
| if args.sanity_check: | |
| print("🔍 SANITY CHECK MODE: Reverting skip to 0 to test training data.") | |
| args.skip = 0 | |
| print(f"Loading model from {args.model_dir}...", flush=True) | |
| processor = Wav2Vec2Processor.from_pretrained(args.model_dir) | |
| model = Wav2Vec2PhonemeEmbedder.from_pretrained(args.model_dir) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| model.eval() | |
| # Initialize G2P Manager and Audio Preprocessor (SYNC WITH TRAINING) | |
| print("Initializing components (G2P and AudioPreprocessor)...", flush=True) | |
| g2p_manager = G2PManager() | |
| preprocessor_utils = AudioPreprocessor(sr=16000) | |
| # Load ID to Phoneme mapping | |
| vocab_path = os.path.join(args.model_dir, "vocab.json") | |
| with open(vocab_path, "r", encoding="utf8") as f: | |
| vocab = json.load(f) | |
| id2phoneme = {v: k for k, v in vocab.items()} | |
| pad_id = processor.tokenizer.pad_token_id | |
| print(f"Loading dataset skbose/indian-english-nptel-v0 (streaming)...", flush=True) | |
| ds = load_dataset("skbose/indian-english-nptel-v0", split=args.split, streaming=True) | |
| ds = ds.cast_column("audio", Audio(decode=False)) | |
| print(f"Skipping {args.skip} samples...", flush=True) | |
| eval_iterable = ds.skip(args.skip).take(args.num_samples) | |
| total_per = 0 | |
| count = 0 | |
| iterator = iter(eval_iterable) | |
| print(f"Evaluating {args.num_samples} samples...", flush=True) | |
| for i in range(args.num_samples): | |
| try: | |
| # 0. Get Sample | |
| try: | |
| sample = next(iterator) | |
| except (RuntimeError, Exception): | |
| print(f"\n⚠️ HF Error at step {i}, re-initializing iterator...", flush=True) | |
| ds_reinit = load_dataset("skbose/indian-english-nptel-v0", split=args.split, streaming=True) | |
| ds_reinit = ds_reinit.cast_column("audio", Audio(decode=False)) | |
| iterator = iter(ds_reinit.skip(args.skip + i).take(args.num_samples - i)) | |
| sample = next(iterator) | |
| # 1. Decode audio | |
| audio_bytes = sample["audio"]["bytes"] | |
| with io.BytesIO(audio_bytes) as f: | |
| audio_array, sr = sf.read(f) | |
| # Resample to 16kHz if needed (Matching train_streaming.py) | |
| if sr != 16000: | |
| import librosa | |
| audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000) | |
| # 1.1 SYNC PREPROCESSING: FFT Filter + VAD Trim | |
| audio_data = preprocessor_utils.preprocess(audio_array) | |
| # Skip extremely long audio | |
| if len(audio_data) / 16000 > 30: | |
| continue | |
| # Ensure float32 | |
| audio_data = audio_data.astype(np.float32) | |
| # 2. Preprocess with Processor (Group Norm happens here) | |
| inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt", padding=True) | |
| input_values = inputs.input_values.to(device) | |
| # 3. Inference | |
| with torch.no_grad(): | |
| outputs = model(input_values=input_values, return_dict=True) | |
| if isinstance(outputs, (dict, OrderedDict)): | |
| logits = outputs.get("logits") | |
| elif hasattr(outputs, "logits"): | |
| logits = outputs.logits | |
| else: | |
| logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs | |
| pred_ids = torch.argmax(logits, dim=-1)[0].cpu().numpy() | |
| # 4. Collapse CTC | |
| collapsed = [] | |
| prev = None | |
| unk_count = 0 | |
| for pid in pred_ids: | |
| if pid != pad_id: | |
| if pid == 1: unk_count += 1 | |
| if pid != prev: | |
| collapsed.append(id2phoneme.get(int(pid), "<unk>")) | |
| prev = pid | |
| # 5. Transcription handling | |
| trans = sample.get("transcription_normalised") or sample.get("transcription") or "" | |
| trans = str(trans) | |
| if not trans.strip(): continue | |
| target_phonemes = g2p_manager.convert_sentence(trans) | |
| # 6. PER | |
| per = calculate_per(target_phonemes, collapsed) | |
| total_per += per | |
| count += 1 | |
| # Sample display | |
| if i < 3 or (i % 20 == 0): | |
| print(f"\n--- Sample {i+1} ---", flush=True) | |
| print(f"Ref: {' '.join(target_phonemes[:20])}...", flush=True) | |
| print(f"Hyp: {' '.join(collapsed[:20])}...", flush=True) | |
| print(f"Stat: {len(pred_ids)} frames, {unk_count} <unk> frames.", flush=True) | |
| print(f"PER: {per:.2%}", flush=True) | |
| elif (i+1) % 5 == 0: | |
| print(f"Processed {i+1}/{args.num_samples}...", end="\r", flush=True) | |
| except StopIteration: | |
| break | |
| except Exception as e: | |
| print(f"Error processing sample {i}: {e}", flush=True) | |
| continue | |
| if count > 0: | |
| avg_per = total_per / count | |
| print(f"\n\n{'='*40}") | |
| print(f"FINAL RESULTS: PER = {avg_per:.2%}") | |
| print(f"{'='*40}") | |
| if avg_per > 0.8: | |
| print("⚠️ WARNING: High PER detected. This usually indicates under-training or a vocab mismatch.") | |
| print("👉 Try running with --sanity_check to see performance on training data.") | |
| else: | |
| print("\nNo samples were successfully evaluated.") | |
| if __name__ == "__main__": | |
| main() | |