ASR / src /inference /cli_test_menu.py
MihirRPatil's picture
deploy: CDAC ASR backend with pitch/stress fix and LLM feedback
88a679b
Raw
History Blame Contribute Delete
3.79 kB
import os
import sys
import torch
import json
import sounddevice as sd
import numpy as np
from transformers import Wav2Vec2Processor
from src.models.phoneme_embedder import Wav2Vec2PhonemeEmbedder
from src.utils.audio_utils import AudioPreprocessor
def clear_screen():
    """Wipe the terminal by running the host shell's clear command.

    Runs ``cls`` when ``os.name`` is ``'nt'`` and ``clear`` on every other
    platform. The command's exit status is ignored.
    """
    command = 'clear'
    if os.name == 'nt':
        command = 'cls'
    os.system(command)
class LiveTester:
    """Record microphone audio and turn it into a phoneme sequence.

    The instance owns the processor, the model (moved to CUDA when available,
    otherwise CPU, and switched to eval mode), the id -> phoneme table built
    from ``vocab.json`` and the audio pre-processor.
    """

    def __init__(self, model_dir):
        """Load the processor, model and vocabulary stored in ``model_dir``.

        Args:
            model_dir: Directory handed to both ``from_pretrained`` calls. It
                must also contain a ``vocab.json`` file mapping phoneme
                symbols to integer ids.
        """
        print(f"Initializing model from {model_dir}...")
        # Single source of truth for the sample rate used for recording,
        # pre-processing and feature extraction.
        self.sr = 16000

        self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
        self.model = Wav2Vec2PhonemeEmbedder.from_pretrained(model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        vocab_path = os.path.join(model_dir, "vocab.json")
        with open(vocab_path, "r", encoding="utf8") as f:
            self.vocab = json.load(f)
        # Invert symbol -> id so predicted ids can be mapped back to symbols.
        self.id2phoneme = {v: k for k, v in self.vocab.items()}
        self.pad_id = self.processor.tokenizer.pad_token_id

        self.audio_prep = AudioPreprocessor(sr=self.sr)

    def _collapse_repeats(self, pred_ids):
        """Map per-frame ids to phoneme symbols.

        Consecutive duplicate ids are merged into one and ids equal to
        ``self.pad_id`` are dropped. Ids missing from the vocabulary are
        rendered as ``"<unk>"``.
        """
        collapsed = []
        prev = None
        for pid in pred_ids:
            if pid != prev and pid != self.pad_id:
                collapsed.append(self.id2phoneme.get(pid, "<unk>"))
            # Updated for every frame (pad included), so a pad between two
            # equal ids keeps them as two separate phonemes.
            prev = pid
        return collapsed

    def record_and_transcribe(self, duration=3.0):
        """Record from the default input device and return predicted phonemes.

        Args:
            duration: Recording length in seconds. Must be a positive, finite
                number long enough to produce at least one sample at
                ``self.sr``.

        Returns:
            A list of phoneme strings; empty when the pre-processed signal
            contains no samples.

        Raises:
            ValueError: If ``duration`` is zero, negative, NaN, infinite or
                too short to yield a single sample.
        """
        # The chained comparison is False for NaN, so NaN is rejected as well.
        if not 0 < duration < float("inf"):
            raise ValueError(
                f"duration must be a positive, finite number of seconds, got {duration!r}"
            )
        num_frames = int(duration * self.sr)
        if num_frames < 1:
            raise ValueError(
                f"duration {duration!r} is too short to record a sample at {self.sr} Hz"
            )

        print(f"\n🎤 Recording for {duration} seconds... Speak now!")
        recording = sd.rec(num_frames, samplerate=self.sr, channels=1)
        sd.wait()  # block until the recording buffer has been filled
        print("✅ Recording complete. Processing...")

        audio = recording.squeeze()
        # Clean audio with VAD/FFT
        audio = self.audio_prep.preprocess(audio)
        # NOTE(review): assumes preprocess() returns an array-like and may
        # trim the signal down to nothing; skip the model in that case.
        if np.size(audio) == 0:
            return []

        inputs = self.processor(
            audio, sampling_rate=self.sr, return_tensors="pt", padding=True
        )
        input_values = inputs.input_values.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values).logits

        # Most likely id per frame for the first (only) item in the batch, as
        # plain Python ints so they match the keys of ``self.id2phoneme``.
        pred_ids = torch.argmax(logits, dim=-1)[0].cpu().tolist()
        return self._collapse_repeats(pred_ids)
def _parse_duration(text):
    """Return ``text`` as a positive, finite number of seconds.

    Raises:
        ValueError: If ``text`` is not a number, or the number is zero,
            negative, NaN or infinite.
    """
    seconds = float(text)
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not 0 < seconds < float("inf"):
        raise ValueError(f"duration must be a positive, finite number: {text!r}")
    return seconds


def _run_test(tester, duration):
    """Record for ``duration`` seconds and print the predicted phonemes."""
    result = tester.record_and_transcribe(duration)
    print("\nPREDICTED PHONEMES:")
    print(" ".join(result))


def main():
    """Run the interactive live-test menu until the user chooses to exit.

    Prints an error and returns immediately when the hard-coded model
    directory does not exist. Ctrl-C or end-of-input at any prompt leaves the
    menu cleanly instead of ending in a traceback.
    """
    model_dir = "trained_models/20k_steps"
    if not os.path.exists(model_dir):
        print(f"Error: {model_dir} not found. Adjust the path in cli_test_menu.py")
        return

    tester = LiveTester(model_dir)
    # Menu options that record for a fixed number of seconds.
    fixed_durations = {'1': 3.0, '2': 5.0}

    try:
        while True:
            clear_screen()
            print("="*50)
            print("🎙️ PHONEME EMBEDDER LIVE TEST MENU")
            print("="*50)
            print(f"Model: {model_dir}")
            print("-"*50)
            print("1. Quick Test (3 seconds)")
            print("2. Long Test (5 seconds)")
            print("3. Custom Duration")
            print("4. Exit")
            print("-"*50)

            # Tolerate stray whitespace around the typed option.
            choice = input("Select an option: ").strip()
            if choice in fixed_durations:
                _run_test(tester, fixed_durations[choice])
            elif choice == '3':
                # The recording call stays inside the try, as in the original
                # flow, so a ValueError raised while recording is reported the
                # same way as an unparsable number.
                try:
                    dur = _parse_duration(input("Enter duration in seconds: "))
                    _run_test(tester, dur)
                except ValueError:
                    print("Invalid duration.")
            elif choice == '4':
                print("Goodbye!")
                break
            else:
                print("Invalid choice.")
            # Pause so the output is readable before the screen is cleared.
            input("\nPress Enter to continue...")
    except (KeyboardInterrupt, EOFError):
        print("\nGoodbye!")
# Script entry point: start the menu only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()