AvtnshM commited on
Commit
6d70b1d
·
verified ·
1 Parent(s): c511528
Files changed (1) hide show
  1. app.py +81 -94
app.py CHANGED
@@ -1,104 +1,91 @@
1
  import time
2
- import librosa
3
  import gradio as gr
4
- from transformers import AutoModelForCTC, AutoProcessor, pipeline
5
- from jiwer import wer, cer
6
-
7
- # ---------------------------
8
- # Load Models (CPU only)
9
- # ---------------------------
10
-
11
- # 1. IndicConformer
12
- indic_model_id = "ai4bharat/indic-conformer-600m-multilingual"
13
- indic_processor = AutoProcessor.from_pretrained(indic_model_id)
14
- indic_model = AutoModelForCTC.from_pretrained(indic_model_id)
15
- indic_pipeline = pipeline(
16
- "automatic-speech-recognition",
17
- model=indic_model,
18
- tokenizer=indic_processor.tokenizer,
19
- feature_extractor=indic_processor.feature_extractor,
20
- device=-1 # CPU
21
  )
 
22
 
23
- # 2. Facebook MMS (generic multilingual ASR)
24
- mms_model_id = "facebook/mms-1b-all"
25
- mms_processor = AutoProcessor.from_pretrained(mms_model_id)
26
- mms_model = AutoModelForCTC.from_pretrained(mms_model_id)
27
- mms_pipeline = pipeline(
28
- "automatic-speech-recognition",
29
- model=mms_model,
30
- tokenizer=mms_processor.tokenizer,
31
- feature_extractor=mms_processor.feature_extractor,
32
- device=-1
33
- )
34
-
35
- # 3. Jivi AudioX (North example)
36
- jivi_model_id = "jiviai/audioX-north-v1"
37
- jivi_pipeline = pipeline(
38
- "automatic-speech-recognition",
39
- model=jivi_model_id,
40
- device=-1
41
- )
42
-
43
- # ---------------------------
44
- # Utility Functions
45
- # ---------------------------
46
-
47
- def evaluate_model(pipeline_fn, audio_path, reference_text):
48
- # Load audio (resample to 16kHz for consistency)
49
- speech, sr = librosa.load(audio_path, sr=16000)
50
-
51
- # Measure runtime
52
- start = time.time()
53
- result = pipeline_fn(speech)
54
- end = time.time()
55
-
56
- # Extract transcription
57
- hypothesis = result["text"]
58
-
59
- # Compute metrics
60
- word_error = wer(reference_text.lower(), hypothesis.lower())
61
- char_error = cer(reference_text.lower(), hypothesis.lower())
62
- rtf = (end - start) / (len(speech) / sr) # real-time factor
63
-
64
- return hypothesis, word_error, char_error, rtf
65
-
66
- def compare_models(audio, reference_text, lang="hi"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  results = {}
 
 
 
68
 
69
- # IndicConformer
70
- hyp, w, c, r = evaluate_model(indic_pipeline, audio, reference_text)
71
- results["IndicConformer"] = (hyp, w, c, r)
72
-
73
- # MMS
74
- hyp, w, c, r = evaluate_model(mms_pipeline, audio, reference_text)
75
- results["MMS"] = (hyp, w, c, r)
76
-
77
- # Jivi
78
- hyp, w, c, r = evaluate_model(jivi_pipeline, audio, reference_text)
79
- results["Jivi"] = (hyp, w, c, r)
80
-
81
- # Build results table
82
- table = "| Model | Transcription | WER | CER | RTF |\n"
83
- table += "|-------|---------------|-----|-----|-----|\n"
84
- for model, (hyp, w, c, r) in results.items():
85
- table += f"| {model} | {hyp} | {w:.3f} | {c:.3f} | {r:.3f} |\n"
86
-
87
- return table
88
-
89
- # ---------------------------
90
- # Gradio UI
91
- # ---------------------------
92
  demo = gr.Interface(
93
- fn=compare_models,
94
- inputs=[
95
- gr.Audio(type="filepath", label="Upload Audio (≤20s recommended)"),
96
- gr.Textbox(label="Reference Text"),
97
- gr.Dropdown(choices=["hi", "gu", "ta"], value="hi", label="Language")
98
- ],
99
- outputs=gr.Markdown(label="Results"),
100
- title="ASR Benchmark (CPU mode): IndicConformer vs MMS vs Jivi",
101
- description="Runs on free CPU Spaces. Upload short audio and reference text. Compares models on WER, CER, and RTF."
102
  )
103
 
104
  if __name__ == "__main__":
 
1
  import time
2
+ import torch
3
  import gradio as gr
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoProcessor,
7
+ AutoModelForCTC,
8
+ WhisperProcessor,
9
+ WhisperForConditionalGeneration,
10
+ pipeline,
 
 
 
 
 
 
 
 
 
 
11
  )
12
+ from jiwer import wer, cer
13
 
14
+ # -----------------------------
15
+ # Load sample dataset (Hindi)
16
+ # -----------------------------
17
+ # We’ll use a few samples for faster CPU benchmarking
18
+ test_ds = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test[:3]")
19
+
20
+ # -----------------------------
21
+ # Model configs
22
+ # -----------------------------
23
+ models = {
24
+ "IndicWhisper (Hindi)": {
25
+ "id": "ai4bharat/indicwhisper-large-hi",
26
+ "type": "whisper",
27
+ },
28
+ "IndicConformer": {
29
+ "id": "ai4bharat/indic-conformer-600m-multilingual",
30
+ "type": "conformer",
31
+ },
32
+ "MMS (Facebook)": {
33
+ "id": "facebook/mms-1b-all",
34
+ "type": "conformer",
35
+ },
36
+ }
37
+
38
+ # -----------------------------
39
+ # Helper function for inference
40
+ # -----------------------------
41
+ def evaluate_model(name, cfg, dataset):
42
+ print(f"\nRunning {name}...")
43
+ start_time = time.time()
44
+
45
+ if cfg["type"] == "whisper":
46
+ processor = WhisperProcessor.from_pretrained(cfg["id"])
47
+ model = WhisperForConditionalGeneration.from_pretrained(cfg["id"]).to("cpu")
48
+ pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, device=-1)
49
+
50
+ else: # Conformer (Indic or MMS)
51
+ processor = AutoProcessor.from_pretrained(cfg["id"], trust_remote_code=True)
52
+ model = AutoModelForCTC.from_pretrained(cfg["id"], trust_remote_code=True).to("cpu")
53
+ pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, device=-1)
54
+
55
+ preds, refs = [], []
56
+ for sample in dataset:
57
+ audio = sample["audio"]["array"]
58
+ ref_text = sample["sentence"]
59
+ out = pipe(audio)
60
+ preds.append(out["text"])
61
+ refs.append(ref_text)
62
+
63
+ elapsed = time.time() - start_time
64
+ rtf = elapsed / sum(len(s["audio"]["array"]) / 16000 for s in dataset)
65
+
66
+ return {
67
+ "WER": wer(refs, preds),
68
+ "CER": cer(refs, preds),
69
+ "RTF": rtf,
70
+ "Predictions": preds,
71
+ "References": refs,
72
+ }
73
+
74
+ # -----------------------------
75
+ # Gradio UI
76
+ # -----------------------------
77
+ def run_comparison():
78
  results = {}
79
+ for name, cfg in models.items():
80
+ results[name] = evaluate_model(name, cfg, test_ds)
81
+ return results
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  demo = gr.Interface(
84
+ fn=run_comparison,
85
+ inputs=[],
86
+ outputs="json",
87
+ title="Indic ASR Benchmark (CPU)",
88
+ description="Compares IndicWhisper (Hindi), IndicConformer, and MMS on WER, CER, and RTF.",
 
 
 
 
89
  )
90
 
91
  if __name__ == "__main__":