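# Page header captured from the Hugging Face file viewer (not part of the program):
#   author avatar: nabin2004
#   commit 0affa8b (verified): "Upload folder using huggingface_hub"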
import gradio as gr
import torch
import pickle
import os
from pydub import AudioSegment
from tempfile import NamedTemporaryFile
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf
from pathlib import Path
# ------------------ Paths ------------------ #
# Directory containing this file; bundled assets are resolved relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent
# Pickled word trie that load_trie() opens and deserializes.
TRIE_PICKLE_PATH = PROJECT_ROOT / "model/nepali_words_trie.pkl"
# Unused: path to a local .nemo checkpoint (the ASR model is fetched with from_pretrained instead).
# ASR_MODEL_PATH = PROJECT_ROOT / "model/indicconformer_stt_ne_hybrid_rnnt_large.nemo"
# Identifier passed to GPT2Tokenizer.from_pretrained / GPT2LMHeadModel.from_pretrained.
GPT2_MODEL_NAME = "nabin2004/nepali_GPT2"
# ------------------ Trie Classes ------------------ #
class TrieNode:
    """A single node of the word trie.

    Attributes are plain instance attributes so that previously pickled
    nodes keep loading unchanged.
    """

    def __init__(self):
        # True when the path from the root to this node spells a whole word.
        self.is_end_of_word: bool = False
        # Outgoing edges: one entry per next character, each with its own child node.
        self.children: dict = dict()
class Trie:
    """Character trie used for exact word lookup and prefix completion."""

    def __init__(self):
        self.root = TrieNode()

    def _walk(self, text):
        """Follow *text* character by character from the root.

        Returns the node reached, or None when some character has no edge.
        """
        cursor = self.root
        for ch in text:
            cursor = cursor.children.get(ch)
            if cursor is None:
                return None
        return cursor

    def insert(self, word):
        """Add *word* to the trie, creating any missing nodes along its path."""
        cursor = self.root
        for ch in word:
            following = cursor.children.get(ch)
            if following is None:
                following = cursor.children[ch] = TrieNode()
            cursor = following
        cursor.is_end_of_word = True

    def is_word_spelled_correctly(self, word):
        """Return whether *word* was inserted as a complete word.

        A string that is only a prefix of stored words yields False.
        """
        reached = self._walk(word)
        if reached is None:
            return False
        return reached.is_end_of_word

    def suggest_words(self, prefix, max_suggestions=10):
        """Return up to *max_suggestions* stored words starting with *prefix*.

        Words are produced in depth-first pre-order, visiting children in the
        order their edges were created. An unknown prefix gives an empty list.
        """
        start = self._walk(prefix)
        if start is None:
            return []

        found = []
        # Explicit stack of (node, text spelled so far). Children are pushed in
        # reverse so they are popped in their original order.
        pending = [(start, prefix)]
        while pending and len(found) < max_suggestions:
            current, spelled = pending.pop()
            if current.is_end_of_word:
                found.append(spelled)
            for ch, child in reversed(list(current.children.items())):
                pending.append((child, spelled + ch))
        return found
# ------------------ Load Trie ------------------ #
def load_trie():
    """Deserialize the word trie stored at TRIE_PICKLE_PATH.

    The pickle may name Trie/TrieNode under a different module path than the
    one this file runs as, so class lookup is redirected by class name alone
    to the definitions in this file.

    Returns:
        The object stored in the pickle file.
    """
    class TrieUnpickler(pickle.Unpickler):
        """Unpickler that resolves the trie classes by name, ignoring the module."""

        def find_class(self, module, name):
            local_classes = {"Trie": Trie, "TrieNode": TrieNode}
            if name in local_classes:
                return local_classes[name]
            # Anything else is resolved the standard way.
            return super().find_class(module, name)

    # NOTE: unpickling can execute arbitrary code; only load a trusted file.
    with open(TRIE_PICKLE_PATH, "rb") as pickle_file:
        return TrieUnpickler(pickle_file).load()
# Build the word trie once at import time; correct_tokens() reads this global.
trie = load_trie()
# ------------------ Load Models ------------------ #
# Resolve the target device once so both models are placed consistently.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("🔄 Loading ASR model...")
asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name="nabin2004/finetuned_Indic_Conformer_nepali")
asr_model.eval()
asr_model.to(_device)
asr_model.change_decoding_strategy(OmegaConf.create({"strategy": "greedy", "nbest": 5}))
# NOTE(review): presumably selects the CTC branch of the hybrid RNNT/CTC model
# for later transcribe() calls — confirm against the NeMo documentation.
asr_model.cur_decoder = "ctc"
print("✅ ASR model loaded.")

print("🔄 Loading GPT-2 model...")
tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_NAME)
gpt2_model.eval()
gpt2_model.to(_device)
print("✅ GPT-2 loaded.")
# ------------------ GPT-2 Spell Ranking ------------------ #
def score_with_gpt2(text):
    """Score *text* with the GPT-2 language model; higher means more likely.

    Args:
        text: The sentence to evaluate.

    Returns:
        The negated language-modelling loss reported by the model, or
        ``float("-inf")`` when the text encodes to fewer than two tokens.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt')

    # GPT-2 has a fixed number of position embeddings. Keep only the most
    # recent tokens so a long running context cannot overflow it, and so the
    # candidate word at the end of the text is never the part cut off.
    max_positions = getattr(gpt2_model.config, "n_positions", None)
    if max_positions:
        input_ids = input_ids[:, -max_positions:]

    # With labels the model predicts each token from the ones before it, so a
    # sequence shorter than two tokens has nothing to score.
    if input_ids.shape[1] < 2:
        return float("-inf")

    input_ids = input_ids.to(gpt2_model.device)
    with torch.no_grad():
        output = gpt2_model(input_ids, labels=input_ids)
    return -output.loss.item()
def rank_with_gpt2(suggestions, context):
    """Pick the suggestion that GPT-2 rates best when appended to *context*.

    Args:
        suggestions: Candidate replacement words.
        context: List of the words that precede the candidate.

    Returns:
        The highest-scoring candidate (the earliest one on ties), or an empty
        string when there are no suggestions.
    """
    if not suggestions:
        return ""

    winner, top_score = suggestions[0], float("-inf")
    for candidate in suggestions:
        sentence = ' '.join(context + [candidate])
        candidate_score = score_with_gpt2(sentence)
        # Strict comparison: a later candidate must beat, not merely match, the leader.
        if candidate_score > top_score:
            winner, top_score = candidate, candidate_score
    return winner
def correct_tokens(tokens):
    """Spell-correct a sequence of words and join them into one string.

    Words found in the trie are kept as they are. Any other word is replaced
    by the trie completion of its first two characters that GPT-2 rates best
    given the words already accepted. When the trie offers no completion the
    original word is kept, so it is never dropped from the output.

    Args:
        tokens: Iterable of words to check.

    Returns:
        The corrected words joined by single spaces.
    """
    corrected_tokens = []
    for token in tokens:
        if trie.is_word_spelled_correctly(token):
            corrected_tokens.append(token)
            continue

        # Candidates are the stored words sharing the token's first two characters.
        suggestions = trie.suggest_words(token[:2])
        best = rank_with_gpt2(suggestions, corrected_tokens)
        # rank_with_gpt2 returns "" when there are no candidates; fall back to
        # the unmodified token in that case.
        corrected_tokens.append(best or token)
    return ' '.join(corrected_tokens)
# ------------------ Audio Transcription ------------------ #
def transcribe(audio_path):
    """Transcribe an audio file and return the spell-corrected text.

    Args:
        audio_path: Path of the audio file to transcribe; a falsy value means
            no audio was supplied.

    Returns:
        The corrected transcription, or an error message when no audio was
        supplied.
    """
    if not audio_path:
        return "❌ No audio found."

    # Down-mix to a single channel before exporting for the ASR model.
    audio = AudioSegment.from_file(audio_path).set_channels(1)

    # Reserve a temporary .wav path and close our handle right away, so the
    # export below does not write to a file that is still held open.
    with NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        processed_path = tmp_file.name
    try:
        audio.export(processed_path, format="wav")
        result = asr_model.transcribe([processed_path], language_id="ne")[0][0]
    finally:
        # Remove the temporary file even when export or transcription fails.
        os.remove(processed_path)

    return correct_tokens(result.split())
# ------------------ Gradio UI ------------------ #
# Web UI: one audio input, one text output, and two buttons.
# NOTE(review): the original nesting was not recoverable; the Row is assumed to
# hold only the audio input and the output box — confirm the intended layout.
with gr.Blocks(title="Nepali ASR & Spell Checker") as demo:
    gr.Markdown("## 🇳🇵 Nepali Speech-to-Text with Spell Correction")
    with gr.Row():
        audio_upload = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio")
        output_text = gr.Textbox(label="📝 Corrected Transcription", lines=8, placeholder="Output will appear here...")
    transcribe_btn = gr.Button("🚀 Transcribe")
    clear_btn = gr.Button("🧹 Clear")
    # Run the ASR + spell-correction pipeline on the selected audio file.
    transcribe_btn.click(fn=transcribe, inputs=[audio_upload], outputs=output_text)
    # Reset both widgets: no audio, empty text.
    clear_btn.click(lambda: (None, ""), outputs=[audio_upload, output_text])
demo.launch()