| import gradio as gr |
| import os |
| import shutil |
| import tarfile |
| import torch |
| import soundfile as sf |
| import torchaudio.functional as F |
| from huggingface_hub import snapshot_download |
| from omegaconf import OmegaConf |
| from nemo.collections.asr.models import ASRModel |
|
|
| |
# Hugging Face repository that holds the unpacked model files.
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"

print("1. Downloading unzipped model files from Hugging Face...")
model_dir = snapshot_download(repo_id=MODEL_NAME)

print("2. Patching the configuration...")
config_path = os.path.join(model_dir, "model_config.yaml")
config = OmegaConf.load(config_path)

# Clear the struct flag before keys are removed or reassigned below.
OmegaConf.set_struct(config, False)

# Drop the SDPA-related encoder options when an encoder section exists.
# NOTE(review): presumably these keys are not understood by the installed
# NeMo version -- confirm.
if 'encoder' in config:
    for sdpa_option in ('use_pytorch_sdpa', 'use_pytorch_sdpa_backends'):
        config.encoder.pop(sdpa_option, None)

if 'decoding' in config:
    decoding_cfg = config.decoding

    # Replace the batched greedy strategy with the plain greedy one.
    if decoding_cfg.get('strategy') == 'greedy_batch':
        print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
        decoding_cfg.strategy = 'greedy'

    # Strip 'tdt_include_duration' from a non-null confidence section.
    has_confidence_cfg = (
        'confidence_cfg' in decoding_cfg
        and decoding_cfg.confidence_cfg is not None
    )
    if has_confidence_cfg and 'tdt_include_duration' in decoding_cfg.confidence_cfg:
        print(" -> Removing 'tdt_include_duration' from confidence_cfg")
        decoding_cfg.confidence_cfg.pop('tdt_include_duration', None)
|
|
print("3. Packaging files into a standard .nemo archive...")
patched_dir = "patched_nemo_env"

# Start from an empty staging directory: everything found in it is added to
# the archive below, so files left behind by an earlier run must not survive.
shutil.rmtree(patched_dir, ignore_errors=True)
os.makedirs(patched_dir, exist_ok=True)

# Copy the regular files of the downloaded snapshot (sub-directories are
# skipped) into the staging directory.
for item in os.listdir(model_dir):
    s = os.path.join(model_dir, item)
    d = os.path.join(patched_dir, item)
    if os.path.isfile(s):
        shutil.copy2(s, d)

# Overwrite the copied configuration with the patched one.
OmegaConf.save(config, os.path.join(patched_dir, "model_config.yaml"))

# Bundle the staged files into an uncompressed tar archive; every file is
# stored at the archive root under its bare file name.
nemo_filepath = "patched_model.nemo"
with tarfile.open(nemo_filepath, "w") as tar:
    for item in os.listdir(patched_dir):
        tar.add(os.path.join(patched_dir, item), arcname=item)

print("4. Restoring model from patched .nemo file (this may take a moment)...")
model = ASRModel.restore_from(restore_path=nemo_filepath)
model.eval()  # switch the restored model to evaluation mode
print("Model loaded successfully!")
|
|
def transcribe(audio_path):
    """Transcribe an audio file with the module-level ASR ``model``.

    The audio is read with soundfile, down-mixed to a single channel,
    resampled to 16 kHz when its sample rate differs, and written to a
    temporary WAV file next to the input. That file is handed to
    ``model.transcribe`` and removed again before the function returns.

    Args:
        audio_path: Path of the audio file to transcribe. A falsy value
            (``None`` or ``""``) means that no audio was supplied.

    Returns:
        The transcription text; a prompt asking for audio when
        ``audio_path`` is falsy; or a message starting with
        ``"Error during transcription: "`` if any processing step raises.
    """
    if not audio_path:
        return "Please upload or record audio."

    processed_path = None
    try:
        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()

        # Bring the tensor to a channels-first layout: 1-D data gets a
        # leading channel axis, 2-D data has its two axes swapped.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.transpose(0, 1)

        # Average all channels into one.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != 16000:
            waveform = F.resample(waveform, sample_rate, 16000)

        processed_path = audio_path + "_mono_16k.wav"
        sf.write(processed_path, waveform.squeeze(0).numpy(), 16000)

        # NOTE(review): the first parameter is named `paths2audio_files` in
        # older NeMo releases and `audio` in newer ones; passing the list
        # positionally works with both -- confirm against the installed
        # version.
        transcription = model.transcribe([processed_path])[0]

        if isinstance(transcription, list):
            transcription = transcription[0]

        # Result objects that expose a `.text` attribute are unwrapped to
        # their text; plain strings have no such attribute and are returned
        # unchanged.
        return getattr(transcription, "text", transcription)

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error during transcription: {str(e)}"

    finally:
        # Best-effort cleanup so converted copies do not pile up on disk.
        if processed_path is not None:
            try:
                os.remove(processed_path)
            except OSError:
                pass
|
|
| |
| with gr.Blocks(title="Malayalam FastConformer ASR") as demo: |
| gr.Markdown("# 🎙️ Malayalam FastConformer Speech-to-Text") |
| gr.Markdown("Upload an audio file or record from your microphone to generate a transcription.") |
| |
| with gr.Row(): |
| with gr.Column(): |
| audio_input = gr.Audio(type="filepath", label="Input Audio") |
| transcribe_btn = gr.Button("Transcribe", variant="primary") |
| |
| with gr.Column(): |
| text_output = gr.Textbox(label="Transcription", lines=5) |
| |
| transcribe_btn.click( |
| fn=transcribe, |