| import gradio as gr |
| import torch |
| import pickle |
| import os |
| from pydub import AudioSegment |
| from tempfile import NamedTemporaryFile |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel |
| import nemo.collections.asr as nemo_asr |
| from omegaconf import OmegaConf |
| from pathlib import Path |
|
|
| |
# Paths are resolved relative to this file so the app works from any CWD.
PROJECT_ROOT = Path(__file__).resolve().parent
# Pickled Trie of known Nepali words used for spell checking.
TRIE_PICKLE_PATH = PROJECT_ROOT / "model/nepali_words_trie.pkl"

# Hugging Face model id of the Nepali GPT-2 used to rank spelling suggestions.
GPT2_MODEL_NAME = "nabin2004/nepali_GPT2"
|
|
| |
class TrieNode:
    """A single trie node: a child map keyed by character plus an end-of-word flag."""

    def __init__(self):
        # No word ends here until Trie.insert() marks it.
        self.is_end_of_word = False
        self.children = {}


class Trie:
    """Prefix tree over Nepali words supporting exact lookup and prefix suggestions."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        """Add *word* to the trie, creating nodes along its character path."""
        current = self.root
        for ch in word:
            current = current.children.setdefault(ch, TrieNode())
        current.is_end_of_word = True

    def is_word_spelled_correctly(self, word):
        """Return True iff *word* was inserted exactly (a mere prefix is False)."""
        current = self.root
        for ch in word:
            current = current.children.get(ch)
            if current is None:
                return False
        return current.is_end_of_word

    def suggest_words(self, prefix, max_suggestions=10):
        """Collect up to *max_suggestions* inserted words that start with *prefix*.

        Words are gathered by depth-first traversal below the prefix node, so
        ordering follows child insertion order. Returns [] for an unknown prefix.
        """
        current = self.root
        for ch in prefix:
            current = current.children.get(ch)
            if current is None:
                return []

        found = []

        def _dfs(node, word_so_far):
            # Stop expanding once the cap is reached.
            if len(found) >= max_suggestions:
                return
            if node.is_end_of_word:
                found.append(word_so_far)
            for ch, child in node.children.items():
                _dfs(child, word_so_far + ch)

        _dfs(current, prefix)
        return found
|
|
| |
def load_trie():
    """Deserialize the Nepali word trie from TRIE_PICKLE_PATH.

    A custom Unpickler remaps any pickled ``Trie``/``TrieNode`` class
    reference onto this module's definitions, so the pickle still loads even
    if it was written when these classes lived under a different module path.
    """

    class _RemappingUnpickler(pickle.Unpickler):
        # Redirect class lookups by bare name; everything else resolves normally.
        def find_class(self, module, name):
            local = {"Trie": Trie, "TrieNode": TrieNode}.get(name)
            if local is not None:
                return local
            return super().find_class(module, name)

    with TRIE_PICKLE_PATH.open("rb") as f:
        return _RemappingUnpickler(f).load()
|
|
| trie = load_trie() |
|
|
| |
# --- Model loading (module-level side effect: runs once at import) ---
# NOTE(review): the original status strings contained mojibake emoji that
# split a string literal across lines; replaced with plain ASCII messages.

print("Loading ASR model...")
# Hybrid RNNT/CTC Conformer fine-tuned for Nepali; decoding uses the CTC head.
asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
    model_name="nabin2004/finetuned_Indic_Conformer_nepali"
)
asr_model.eval()
asr_model.to('cuda' if torch.cuda.is_available() else 'cpu')
asr_model.change_decoding_strategy(OmegaConf.create({"strategy": "greedy", "nbest": 5}))
asr_model.cur_decoder = "ctc"
print("ASR model loaded.")


print("Loading GPT-2 model...")
# Nepali GPT-2 used as a language model to rank spelling suggestions.
tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_NAME)
gpt2_model.eval()
gpt2_model.to('cuda' if torch.cuda.is_available() else 'cpu')
print("GPT-2 loaded.")
|
|
| |
def score_with_gpt2(text):
    """Return the negated GPT-2 LM loss of *text* (higher = more fluent)."""
    ids = tokenizer.encode(text, return_tensors='pt').to(gpt2_model.device)
    with torch.no_grad():
        outputs = gpt2_model(ids, labels=ids)
    return -outputs.loss.item()
|
|
def rank_with_gpt2(suggestions, context):
    """Return the suggestion that reads most fluently after *context* per GPT-2.

    *context* is a list of preceding words; each candidate is scored on the
    joined sentence "context + candidate". Returns "" when *suggestions* is
    empty, otherwise ties keep the earliest candidate.
    """
    if not suggestions:
        return ""
    best_word, best_score = suggestions[0], float("-inf")
    for candidate in suggestions:
        candidate_score = score_with_gpt2(' '.join(context + [candidate]))
        if candidate_score > best_score:
            best_word, best_score = candidate, candidate_score
    return best_word
|
|
def correct_tokens(tokens):
    """Spell-check *tokens* against the trie, replacing misspelled words.

    Correctly spelled tokens pass through unchanged. For a misspelled token,
    candidates sharing its first two characters are pulled from the trie and
    the most fluent one (per GPT-2, given the words corrected so far) is
    chosen. If no candidate exists, the original token is kept — previously
    an empty string was appended, silently dropping the word and producing
    double spaces in the output.

    Returns the corrected tokens joined into one space-separated string.
    """
    corrected_tokens = []
    for token in tokens:
        if trie.is_word_spelled_correctly(token):
            corrected_tokens.append(token)
        else:
            # 2-character prefix bounds the candidate search.
            suggestions = trie.suggest_words(token[:2])
            best = rank_with_gpt2(suggestions, corrected_tokens)
            corrected_tokens.append(best if best else token)
    return ' '.join(corrected_tokens)
|
|
| |
def transcribe(audio_path):
    """Transcribe a Nepali audio file and return the spell-corrected text.

    The input is converted to mono WAV via a temporary file before being fed
    to the ASR model. The temp file is now removed in a ``finally`` block, so
    it no longer leaks when transcription raises (the original only deleted
    it on the success path).
    """
    if not audio_path:
        return "No audio found."

    audio = AudioSegment.from_file(audio_path)
    # NOTE(review): only channels are normalized; sample rate is left as-is —
    # confirm the ASR model does not require a specific rate (e.g. 16 kHz).
    audio = audio.set_channels(1)

    with NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        audio.export(tmp_file.name, format="wav")
        processed_path = tmp_file.name

    try:
        result = asr_model.transcribe([processed_path], language_id="ne")[0][0]
    finally:
        # Always clean up the temp WAV, even on ASR failure.
        os.remove(processed_path)

    return correct_tokens(result.split())
|
|
| |
# --- Gradio UI (module-level: builds the app and starts the server) ---
# NOTE(review): original labels contained mojibake emoji; replaced with ASCII.
with gr.Blocks(title="Nepali ASR & Spell Checker") as demo:
    gr.Markdown("## Nepali Speech-to-Text with Spell Correction")

    with gr.Row():
        audio_upload = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Record or Upload Audio",
        )
        output_text = gr.Textbox(
            label="Corrected Transcription",
            lines=8,
            placeholder="Output will appear here...",
        )

    transcribe_btn = gr.Button("Transcribe")
    clear_btn = gr.Button("Clear")

    transcribe_btn.click(fn=transcribe, inputs=[audio_upload], outputs=output_text)
    # Clear resets both the audio input and the transcription box.
    clear_btn.click(lambda: (None, ""), outputs=[audio_upload, output_text])

demo.launch()
|
|