File size: 3,187 Bytes
2e67c80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Varuna STT — inference example.

Usage:
    pip install "nemo_toolkit[asr]>=2.4" omegaconf torch soundfile

    python inference.py --audio path/to/clip.wav

    # Programmatic
    from inference import VarunaSTT
    model = VarunaSTT()
    print(model.transcribe(["a.wav", "b.wav"]))
"""
from __future__ import annotations

import argparse
import warnings
from pathlib import Path

import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecRNNTBPEModel

# ── Paths (adjust if you move the files) ──────────────────────────────────────
# Everything is resolved relative to the directory containing this script.
HERE = Path(__file__).resolve().parent
# Base NeMo archive that the fine-tuned weights are loaded on top of.
NEMOTRON_BASE = HERE.joinpath("nemotron-speech-streaming-en-0.6b.nemo")
# Directory holding the BPE tokenizer files (tokenizer.model, vocab.txt).
TOKENIZER_DIR = HERE
# Fine-tuned checkpoint applied after the base model is restored.
CKPT_PATH = HERE.joinpath("varuna.ckpt")


class VarunaSTT:
    """Hindi speech-to-text built on a fine-tuned NeMo RNN-T BPE model.

    Construction does four things in order:
    1. Restores the base ``.nemo`` archive.
    2. Swaps in the BPE tokenizer from ``tokenizer_dir``.
    3. Switches decoding to greedy-batch.
    4. Loads the fine-tuned weights from ``ckpt`` on top.
    """

    def __init__(self, device: str | None = None,
                 base: Path = NEMOTRON_BASE,
                 ckpt: Path = CKPT_PATH,
                 tokenizer_dir: Path = TOKENIZER_DIR):
        """Build the model and load the fine-tuned weights.

        Args:
            device: Torch device string (e.g. ``"cuda"``, ``"cpu"``). When
                omitted, CUDA is used if available, otherwise CPU.
            base: Path to the base ``.nemo`` archive to restore.
            ckpt: Path to the fine-tuned checkpoint. This is either a raw
                state dict or a dict holding one under ``"state_dict"``.
            tokenizer_dir: Directory passed to ``change_vocabulary`` as a
                BPE tokenizer directory.

        Raises:
            RuntimeError: If no key of the checkpoint's state dict matches the
                model, i.e. the fine-tuned weights would not be applied at all.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EncDecRNNTBPEModel.restore_from(str(base), map_location=self.device)
        # Replace the base tokenizer before the checkpoint is loaded; the
        # checkpoint is assumed to have been trained with this vocabulary.
        self.model.change_vocabulary(new_tokenizer_dir=str(tokenizer_dir),
                                     new_tokenizer_type="bpe")

        # Greedy-batch RNN-T decoding (deterministic, fast on GPU). The config
        # is round-tripped through a plain container to obtain a standalone copy.
        decoding_cfg = OmegaConf.to_container(self.model.cfg.decoding, resolve=True)
        decoding_cfg = OmegaConf.create(decoding_cfg)
        with open_dict(decoding_cfg):
            decoding_cfg.strategy = "greedy_batch"
            if "greedy" not in decoding_cfg:
                decoding_cfg.greedy = {}
            decoding_cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(decoding_cfg)

        self._load_finetuned_weights(ckpt)
        self.model = self.model.to(self.device).eval()

    def _load_finetuned_weights(self, ckpt: Path) -> None:
        """Load ``ckpt`` into ``self.model`` non-strictly, refusing a total miss.

        Raises:
            RuntimeError: If none of the checkpoint's keys match the model.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        state = torch.load(str(ckpt), map_location=self.device, weights_only=False)
        sd = state["state_dict"] if "state_dict" in state else state

        # strict=False tolerates some extra/missing keys. Without checking the
        # result, though, a key-prefix mismatch would silently leave the model
        # with the un-finetuned weights.
        incompatible = self.model.load_state_dict(sd, strict=False)
        missing = list(getattr(incompatible, "missing_keys", []))
        unexpected = list(getattr(incompatible, "unexpected_keys", []))

        if not sd or len(unexpected) == len(sd):
            raise RuntimeError(
                f"No weights from {ckpt} were loaded: none of its "
                f"{len(sd)} keys match the model (check the key prefix)."
            )
        if missing or unexpected:
            warnings.warn(
                f"Partial checkpoint load from {ckpt}: {len(missing)} missing "
                f"and {len(unexpected)} unexpected keys.",
                stacklevel=3,
            )

    @torch.inference_mode()
    def transcribe(self, audio_paths: list[str] | str | Path,
                   batch_size: int = 8) -> list[str]:
        """Transcribe one or more audio files.

        Args:
            audio_paths: A single path, or a list/iterable of paths. Clips are
                expected to be 16 kHz mono; this is not checked here.
            batch_size: Number of clips per forward pass.

        Returns:
            One plain-text (Hindi) transcript per input clip, in input order.
            An empty input yields an empty list.
        """
        # A bare str/Path must be wrapped first. Otherwise list("a.wav") would
        # split it into single characters, each treated as a file path.
        if isinstance(audio_paths, (str, Path)):
            audio_paths = [audio_paths]
        items = [str(p) if isinstance(p, Path) else p for p in audio_paths]
        if not items:
            return []

        out = self.model.transcribe(audio=items,
                                    batch_size=batch_size,
                                    return_hypotheses=False,
                                    verbose=False)
        # If a tuple comes back, keep its first element. This is presumably
        # the best-hypothesis list (returned by some NeMo versions).
        if isinstance(out, tuple):
            out = out[0]
        # Entries may be Hypothesis-like objects or plain strings.
        return [h.text if hasattr(h, "text") else h for h in out]


def main():
    """CLI entry point: transcribe the clips given via --audio and print them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio", nargs="+", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--device", default=None)
    cli = parser.parse_args()

    stt = VarunaSTT(device=cli.device)
    transcripts = stt.transcribe(cli.audio, cli.batch_size)
    # Print each input path followed by its indented transcript.
    for clip, text in zip(cli.audio, transcripts):
        print(f"[{clip}]\n  {text}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()