AvtnshM commited on
Commit
61a53f6
·
verified ·
1 Parent(s): ab67705
Files changed (1) hide show
  1. app.py +52 -73
app.py CHANGED
@@ -1,102 +1,81 @@
1
  import time
2
  import torch
3
  import gradio as gr
4
- from datasets import load_dataset
5
  from transformers import (
6
- AutoProcessor,
7
- AutoModelForCTC,
8
- WhisperProcessor,
9
- WhisperForConditionalGeneration,
10
- pipeline,
11
  )
12
  from jiwer import wer, cer
13
 
14
- # -----------------------------
15
- # Load sample dataset (Hindi, Common Voice 17.0)
16
- # -----------------------------
17
- # Use just 3 samples for faster CPU benchmarking
18
- test_ds = load_dataset("mozilla-foundation/common_voice_17_0", "hi", split="test[:3]")
 
19
 
20
- # -----------------------------
21
- # Model configs
22
- # -----------------------------
23
- models = {
24
- "IndicWhisper (Hindi)": {
25
- "id": "ai4bharat/indicwhisper-large-hi",
26
- "type": "whisper",
27
- },
28
- "IndicConformer": {
29
- "id": "ai4bharat/indic-conformer-600m-multilingual",
30
- "type": "conformer",
31
- },
32
- "MMS (Facebook)": {
33
- "id": "facebook/mms-1b-all",
34
- "type": "conformer",
35
- },
36
- }
37
-
38
- # -----------------------------
39
- # Helper function for inference
40
- # -----------------------------
41
- def evaluate_model(name, cfg, dataset):
42
- print(f"\nRunning {name}...")
43
- start_time = time.time()
44
 
45
  if cfg["type"] == "whisper":
46
- processor = WhisperProcessor.from_pretrained(cfg["id"])
47
- model = WhisperForConditionalGeneration.from_pretrained(cfg["id"]).to("cpu")
48
  pipe = pipeline(
49
  "automatic-speech-recognition",
50
  model=model,
51
- tokenizer=processor.tokenizer,
52
- feature_extractor=processor.feature_extractor,
53
- device=-1,
54
  )
55
- else: # Conformer (Indic or MMS)
56
- processor = AutoProcessor.from_pretrained(cfg["id"], trust_remote_code=True)
57
- model = AutoModelForCTC.from_pretrained(cfg["id"], trust_remote_code=True).to("cpu")
58
  pipe = pipeline(
59
  "automatic-speech-recognition",
60
  model=model,
61
- tokenizer=processor.tokenizer,
62
- feature_extractor=processor.feature_extractor,
63
- device=-1,
64
  )
65
 
66
- preds, refs = [], []
67
- for sample in dataset:
68
- audio = sample["audio"]["array"]
69
- ref_text = sample["sentence"]
70
- out = pipe(audio)
71
- preds.append(out["text"])
72
- refs.append(ref_text)
73
 
74
- elapsed = time.time() - start_time
75
- rtf = elapsed / sum(len(s["audio"]["array"]) / 16000 for s in dataset)
76
 
77
- return {
78
- "WER": wer(refs, preds),
79
- "CER": cer(refs, preds),
80
- "RTF": rtf,
81
- "Predictions": preds,
82
- "References": refs,
83
- }
84
 
85
- # -----------------------------
86
- # Gradio UI
87
- # -----------------------------
88
- def run_comparison():
89
  results = {}
90
- for name, cfg in models.items():
91
- results[name] = evaluate_model(name, cfg, test_ds)
 
 
 
92
  return results
93
 
94
  demo = gr.Interface(
95
- fn=run_comparison,
96
- inputs=[],
97
- outputs="json",
98
- title="Indic ASR Benchmark (CPU)",
99
- description="Compares IndicWhisper (Hindi), IndicConformer, and MMS on WER, CER, and RTF.",
 
 
 
 
100
  )
101
 
102
  if __name__ == "__main__":
 
1
  import time
2
  import torch
3
  import gradio as gr
4
+ import torchaudio
5
  from transformers import (
6
+ WhisperProcessor, WhisperForConditionalGeneration,
7
+ AutoProcessor, AutoModelForCTC, pipeline
 
 
 
8
  )
9
  from jiwer import wer, cer
10
 
11
# Utility to load an audio file and resample it to 16 kHz mono.
def load_audio(fp):
    """Load audio from *fp* and return ``(waveform, 16000)``.

    The waveform is a 1-D float tensor at 16 kHz. Multi-channel input is
    downmixed to mono by averaging channels — the previous ``squeeze(0)``
    only handled mono files and left stereo input as a 2-D tensor, which
    broke the downstream pipeline call and the RTF duration math.
    """
    waveform, sr = torchaudio.load(fp)
    # Downmix to mono (mean over the channel dim); for (1, n) mono input
    # this is identical to the old squeeze(0).
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    return waveform, 16000
17
 
18
# Evaluation function.
# Cache loaded pipelines so repeated Gradio calls don't re-download and
# re-instantiate the (large) models every time.
_PIPELINE_CACHE = {}


def _get_pipeline(cfg):
    """Build (or fetch from cache) a CPU ASR pipeline for one model config."""
    key = cfg["id"]
    if key not in _PIPELINE_CACHE:
        if cfg["type"] == "whisper":
            proc = WhisperProcessor.from_pretrained(cfg["id"])
            model = WhisperForConditionalGeneration.from_pretrained(cfg["id"])
        else:
            proc = AutoProcessor.from_pretrained(cfg["id"], trust_remote_code=True)
            model = AutoModelForCTC.from_pretrained(cfg["id"], trust_remote_code=True)
        _PIPELINE_CACHE[key] = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=proc.tokenizer,
            feature_extractor=proc.feature_extractor,
            device=-1,  # CPU-only Space
        )
    return _PIPELINE_CACHE[key]


def eval_model(name, cfg, file, ref):
    """Transcribe *file* with the model described by *cfg* and score it.

    Parameters
    ----------
    name : str
        Display name of the model (kept for caller symmetry; not used here).
    cfg : dict
        ``{"id": <hub model id>, "type": "whisper" | "conformer"}``.
    file : str
        Path to the uploaded audio file.
    ref : str or None
        Optional reference transcript; WER/CER are ``None`` without it.

    Returns a dict with ``Transcription``, ``WER``, ``CER`` and ``RTF``.
    """
    waveform, sr = load_audio(file)

    pipe = _get_pipeline(cfg)

    # Time only the inference itself; previously the clock also covered
    # model download/instantiation, which grossly inflated RTF.
    start = time.time()
    # The HF ASR pipeline accepts a numpy array or a dict — not a torch
    # tensor. Pass the sampling rate explicitly so no assumption is made.
    result = pipe({"raw": waveform.numpy(), "sampling_rate": sr})
    elapsed = time.time() - start

    hyp = result["text"].lower()
    # ref is already checked by the conditional expression; no need to
    # re-guard with `ref.lower() if ref else ""` as before.
    w = wer(ref.lower(), hyp) if ref else None
    c = cer(ref.lower(), hyp) if ref else None
    # Real-time factor = processing time / audio duration in seconds.
    rtf = elapsed / (waveform.shape[0] / sr)

    return {"Transcription": hyp, "WER": w, "CER": c, "RTF": rtf}
 
51
 
52
# Model registry: UI display name -> Hugging Face hub id and pipeline family
# ("whisper" uses the seq2seq Whisper classes, anything else the CTC classes).
MODELS = {
    "IndicConformer (AI4Bharat)": {
        "id": "ai4bharat/indic-conformer-600m-multilingual",
        "type": "conformer",
    },
    "AudioX-North (Jivi AI)": {
        "id": "jiviai/audioX-north-v1",
        "type": "whisper",
    },
    "MMS (Facebook)": {
        "id": "facebook/mms-1b-all",
        "type": "conformer",
    },
}
 
58
 
59
# Gradio interface logic
def compare_all(audio, reference, language):
    """Run every configured model on *audio* and collect per-model results.

    *language* is accepted from the UI dropdown but is not consumed by any
    model here. A failure in one model is reported inline under its name
    rather than aborting the whole comparison.
    """
    benchmark = {}
    for model_name, model_cfg in MODELS.items():
        try:
            outcome = eval_model(model_name, model_cfg, audio, reference)
        except Exception as exc:
            outcome = {"Error": str(exc)}
        benchmark[model_name] = outcome
    return benchmark
68
 
69
# Build the Gradio UI: an audio upload, an optional reference transcript,
# and a language dropdown; benchmark results render as JSON.
_audio_in = gr.Audio(type="filepath", label="Upload Audio (<=20s recommended)")
_reference_in = gr.Textbox(label="Reference Transcript (optional)")
_language_in = gr.Dropdown(choices=["hi", "gu", "ta"], label="Language", value="hi")

demo = gr.Interface(
    fn=compare_all,
    inputs=[_audio_in, _reference_in, _language_in],
    outputs=gr.JSON(label="Benchmark Results"),
    title="Indic ASR Benchmark (CPU-only)",
    description="Compare IndicConformer, AudioX-North, and MMS on WER, CER, and RTF.",
)
80
 
81
  if __name__ == "__main__":