File size: 4,765 Bytes
14c0892
 
b02797a
 
cd72718
4e887cc
cd72718
c956577
 
14c0892
 
e09c87a
e3a7bb9
14c0892
c956577
 
 
b02797a
c956577
 
 
4fbb118
c956577
4fbb118
 
c956577
 
 
 
e09c87a
4fbb118
e09c87a
4fbb118
 
 
e09c87a
 
 
 
 
 
4fbb118
b02797a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c0892
 
 
 
 
 
 
 
4e887cc
 
 
cd72718
4e887cc
 
 
 
 
 
 
cd72718
 
 
4e887cc
cd72718
 
 
 
4e887cc
cd72718
4e887cc
 
 
cd72718
4e887cc
cd72718
14c0892
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e887cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import os
import shutil
import tarfile
import torch
import soundfile as sf
import torchaudio.functional as F
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from nemo.collections.asr.models import ASRModel

# Hugging Face repository that holds the unpacked NeMo model files.
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"

print("1. Downloading unzipped model files from Hugging Face...")
model_dir = snapshot_download(repo_id=MODEL_NAME)

print("2. Patching the configuration...")
config_path = os.path.join(model_dir, "model_config.yaml")
config = OmegaConf.load(config_path)

# OmegaConf struct mode forbids removing keys; turn it off so the entries
# NeMo 1.23 does not understand can be dropped below.
OmegaConf.set_struct(config, False)

# Patch 1: the PyTorch 2.0 SDPA encoder options are unknown to NeMo 1.23.
if "encoder" in config:
    for _sdpa_key in ("use_pytorch_sdpa", "use_pytorch_sdpa_backends"):
        config.encoder.pop(_sdpa_key, None)

# Patches 2 & 3 only concern the decoding section, when present.
if "decoding" in config:
    _decoding = config.decoding

    # Patch 2: fall back from the batched greedy strategy to plain greedy.
    if _decoding.get("strategy") == "greedy_batch":
        print("   -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
        _decoding.strategy = "greedy"

    # Patch 3: drop 'tdt_include_duration', which NeMo 1.23 does not recognize.
    _confidence = _decoding.confidence_cfg if "confidence_cfg" in _decoding else None
    if _confidence is not None and "tdt_include_duration" in _confidence:
        print("   -> Removing 'tdt_include_duration' from confidence_cfg")
        _confidence.pop("tdt_include_duration", None)

print("3. Packaging files into a standard .nemo archive...")
patched_dir = "patched_nemo_env"
# Start from an empty working directory. Everything found in it is tarred
# below, so leftovers from a previous run (e.g. files that have since been
# removed from the HF repo) would otherwise end up in the archive.
shutil.rmtree(patched_dir, ignore_errors=True)
os.makedirs(patched_dir, exist_ok=True)

# Copy the top-level regular files from the downloaded snapshot into the
# working directory; sub-directories are skipped. os.path.isfile() and
# shutil.copy2() follow symlinks, so symlinked entries are copied as real files.
for item in os.listdir(model_dir):
    s = os.path.join(model_dir, item)
    d = os.path.join(patched_dir, item)
    if os.path.isfile(s):
        shutil.copy2(s, d)

# Overwrite the copied config with the patched version.
OmegaConf.save(config, os.path.join(patched_dir, "model_config.yaml"))

# Tar the files (uncompressed) at the archive root to form the .nemo file.
# Members are added in sorted order so the archive layout is deterministic.
nemo_filepath = "patched_model.nemo"
with tarfile.open(nemo_filepath, "w") as tar:
    for item in sorted(os.listdir(patched_dir)):
        tar.add(os.path.join(patched_dir, item), arcname=item)

print("4. Restoring model from patched .nemo file (this may take a moment)...")
model = ASRModel.restore_from(restore_path=nemo_filepath)
# Switch the restored model to evaluation (inference) mode.
model.eval()
print("Model loaded successfully!")

def transcribe(audio_path):
    if not audio_path:
        return "Please upload or record audio."
    
    try:
        # 1. Load file using soundfile to completely bypass torchcodec bugs
        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()
        
        # 2. Reshape soundfile format [time, channels] to torchaudio format [channels, time]
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # Mono: [time] -> [1, time]
        else:
            waveform = waveform.transpose(0, 1)  # Stereo: [time, channels] -> [channels, time]
        
        # 3. Convert to Mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
            
        # 4. Resample to 16000 Hz if necessary
        if sample_rate != 16000:
            waveform = F.resample(waveform, sample_rate, 16000)
            sample_rate = 16000
            
        # 5. Write the file back out using soundfile
        processed_path = audio_path + "_mono_16k.wav"
        # soundfile expects mono arrays to be flat 1D: [time]
        flat_numpy_waveform = waveform.squeeze(0).numpy()
        sf.write(processed_path, flat_numpy_waveform, 16000)
        
        # 6. Pass to NeMo model
        transcription = model.transcribe(paths2audio_files=[processed_path])[0]
        
        if isinstance(transcription, list):
            return transcription[0]
        return transcription
        
    except Exception as e:
        return f"Error during transcription: {str(e)}"

# Define the Gradio Interface
with gr.Blocks(title="Malayalam FastConformer ASR") as demo:
    gr.Markdown("# 🎙️ Malayalam FastConformer Speech-to-Text")
    gr.Markdown("Upload an audio file or record from your microphone to generate a transcription.")
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        
        with gr.Column():
            text_output = gr.Textbox(label="Transcription", lines=5)
            
    transcribe_btn.click(
        fn=transcribe,