""" Varuna STT — inference example. Usage: pip install nemo_toolkit[asr]>=2.4 omegaconf torch soundfile python inference.py --audio path/to/clip.wav # Programmatic from inference import VarunaSTT model = VarunaSTT() print(model.transcribe(["a.wav", "b.wav"])) """ from __future__ import annotations import argparse from pathlib import Path import torch from omegaconf import OmegaConf, open_dict from nemo.collections.asr.models import EncDecRNNTBPEModel # ── Paths (adjust if you move the files) ────────────────────────────────────── HERE = Path(__file__).resolve().parent NEMOTRON_BASE = HERE / "nemotron-speech-streaming-en-0.6b.nemo" TOKENIZER_DIR = HERE # contains tokenizer.model, vocab.txt CKPT_PATH = HERE / "varuna.ckpt" class VarunaSTT: def __init__(self, device: str | None = None, base: Path = NEMOTRON_BASE, ckpt: Path = CKPT_PATH, tokenizer_dir: Path = TOKENIZER_DIR): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = EncDecRNNTBPEModel.restore_from(str(base), map_location=self.device) self.model.change_vocabulary(new_tokenizer_dir=str(tokenizer_dir), new_tokenizer_type="bpe") # Greedy-batch RNN-T decoding (deterministic, fast on GPU) decoding_cfg = OmegaConf.to_container(self.model.cfg.decoding, resolve=True) decoding_cfg = OmegaConf.create(decoding_cfg) with open_dict(decoding_cfg): decoding_cfg.strategy = "greedy_batch" if "greedy" not in decoding_cfg: decoding_cfg.greedy = {} decoding_cfg.greedy.use_cuda_graph_decoder = False self.model.change_decoding_strategy(decoding_cfg) # Load fine-tuned weights state = torch.load(str(ckpt), map_location=self.device, weights_only=False) sd = state["state_dict"] if "state_dict" in state else state self.model.load_state_dict(sd, strict=False) self.model = self.model.to(self.device).eval() @torch.inference_mode() def transcribe(self, audio_paths: list[str], batch_size: int = 8) -> list[str]: """Transcribe audio file(s) at 16 kHz mono. Returns plain Hindi text per clip.""" out = self.model.transcribe(audio=list(audio_paths), batch_size=batch_size, return_hypotheses=False, verbose=False) if isinstance(out, tuple): out = out[0] return [h.text if hasattr(h, "text") else h for h in out] def main(): ap = argparse.ArgumentParser() ap.add_argument("--audio", nargs="+", required=True) ap.add_argument("--batch-size", type=int, default=8) ap.add_argument("--device", default=None) args = ap.parse_args() model = VarunaSTT(device=args.device) for path, hyp in zip(args.audio, model.transcribe(args.audio, args.batch_size)): print(f"[{path}]\n {hyp}") if __name__ == "__main__": main()