# varuna-stt / inference.py — initial release of Varuna STT (harsh2ai, commit 2e67c80)
"""
Varuna STT — inference example.
Usage:
pip install "nemo_toolkit[asr]>=2.4" omegaconf torch soundfile
python inference.py --audio path/to/clip.wav
# Programmatic
from inference import VarunaSTT
model = VarunaSTT()
print(model.transcribe(["a.wav", "b.wav"]))
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from omegaconf import OmegaConf, open_dict
from nemo.collections.asr.models import EncDecRNNTBPEModel
# ── Paths (adjust if you move the files) ──────────────────────────────────────
HERE = Path(__file__).resolve().parent  # directory this script lives in
NEMOTRON_BASE = HERE / "nemotron-speech-streaming-en-0.6b.nemo"  # base .nemo archive
TOKENIZER_DIR = HERE # contains tokenizer.model, vocab.txt
CKPT_PATH = HERE / "varuna.ckpt"  # fine-tuned Varuna checkpoint overlaid on the base
class VarunaSTT:
    """Speech-to-text wrapper around a fine-tuned NeMo RNN-T BPE model.

    Restores the Nemotron streaming base model, swaps in the Varuna BPE
    tokenizer, switches to deterministic greedy-batch decoding, and overlays
    the fine-tuned checkpoint weights.
    """

    def __init__(self, device: str | None = None,
                 base: Path = NEMOTRON_BASE,
                 ckpt: Path = CKPT_PATH,
                 tokenizer_dir: Path = TOKENIZER_DIR):
        """Load and prepare the model for inference.

        Args:
            device: "cuda" or "cpu"; auto-detects CUDA when None.
            base: path to the base ``.nemo`` archive.
            ckpt: path to the fine-tuned checkpoint (``varuna.ckpt``).
            tokenizer_dir: directory holding ``tokenizer.model`` / ``vocab.txt``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EncDecRNNTBPEModel.restore_from(str(base), map_location=self.device)
        self.model.change_vocabulary(new_tokenizer_dir=str(tokenizer_dir),
                                     new_tokenizer_type="bpe")
        self._configure_decoding()
        self._load_finetuned_weights(ckpt)
        self.model = self.model.to(self.device).eval()

    def _configure_decoding(self) -> None:
        """Switch to greedy-batch RNN-T decoding (deterministic, fast on GPU)."""
        # Round-trip through a plain container so the config is detached and mutable.
        decoding_cfg = OmegaConf.create(
            OmegaConf.to_container(self.model.cfg.decoding, resolve=True))
        with open_dict(decoding_cfg):
            decoding_cfg.strategy = "greedy_batch"
            if "greedy" not in decoding_cfg:
                decoding_cfg.greedy = {}
            # CUDA-graph decoder disabled — presumably for compatibility with
            # this model/NeMo combination; TODO confirm before re-enabling.
            decoding_cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(decoding_cfg)

    def _load_finetuned_weights(self, ckpt: Path) -> None:
        """Overlay the fine-tuned weights onto the restored base model."""
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        state = torch.load(str(ckpt), map_location=self.device, weights_only=False)
        # Lightning checkpoints nest the weights under "state_dict"; a raw
        # state dict is used as-is.
        sd = state.get("state_dict", state)
        # strict=False tolerates extra/missing keys (e.g. non-model entries in
        # a Lightning checkpoint) instead of raising.
        self.model.load_state_dict(sd, strict=False)

    @torch.inference_mode()
    def transcribe(self, audio_paths: list[str], batch_size: int = 8) -> list[str]:
        """Transcribe audio file(s) at 16 kHz mono. Returns plain Hindi text per clip."""
        out = self.model.transcribe(audio=list(audio_paths),
                                    batch_size=batch_size,
                                    return_hypotheses=False,
                                    verbose=False)
        # Some NeMo versions return a (best_hyps, all_hyps) tuple.
        if isinstance(out, tuple):
            out = out[0]
        # Entries may be Hypothesis objects (with .text) or plain strings.
        return [h.text if hasattr(h, "text") else h for h in out]
def main() -> None:
    """CLI entry point: transcribe the given audio files and print each result."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio", nargs="+", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--device", default=None)
    opts = parser.parse_args()

    stt = VarunaSTT(device=opts.device)
    hypotheses = stt.transcribe(opts.audio, batch_size=opts.batch_size)
    for path, hyp in zip(opts.audio, hypotheses):
        print(f"[{path}]\n {hyp}")


if __name__ == "__main__":
    main()