| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| import sounddevice as sd | |
| import scipy.io.wavfile as wav | |
| import numpy as np | |
| import os | |
# Path to the trained checkpoint (a state_dict; loaded in predict()).
MODEL_PATH = "model_best.pth"
# Length of each microphone recording, in seconds.
DURATION = 5
# Character inventory: index 0 ("_") is skipped as the blank token by
# greedy_decoder(); covers the Turkish alphabet plus q/w/x and a trailing space.
VOCAB_STR = "_abcçdefgğhıijklmnoöprsştuüvyzqwx "
class ResCNNBlock(nn.Module):
    """Residual block of two 3x3 Conv-BatchNorm-GELU stages.

    The identity shortcut is applied only when the input and output
    channel counts match; otherwise the block acts as a plain conv stack.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.GELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.GELU()

    def forward(self, x):
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        # Skip connection is only shape-compatible when channels agree.
        if x.shape[1] == out.shape[1]:
            return out + x
        return out
class DeepSpeechModel(nn.Module):
    """CNN + bidirectional-LSTM acoustic model with a per-frame classifier.

    Expects input of shape (batch, 1, time, n_feats) and returns logits of
    shape (batch, time', num_classes), where time' is roughly time/2 after
    the stride-2 entry convolution.

    Args:
        num_classes: size of the output vocabulary (incl. blank).
        n_feats: number of spectrogram feature bins (default 128, matching
            the MelSpectrogram n_mels used in predict()). Previously this
            was hard-coded via ``rnn_input_size = 64 * 64``.
    """

    def __init__(self, num_classes, n_feats=128):
        super(DeepSpeechModel, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.GELU(),
            ResCNNBlock(32, 32),
            ResCNNBlock(32, 32),
            ResCNNBlock(32, 64),
            ResCNNBlock(64, 64),
            nn.Dropout(0.1)
        )
        # The stride-2 conv (k=3, p=1) maps n_feats -> (n_feats - 1)//2 + 1
        # bins; the final CNN stage has 64 channels. For n_feats=128 this
        # reproduces the original 64 * 64 = 4096.
        rnn_input_size = 64 * ((n_feats - 1) // 2 + 1)
        self.dense = nn.Linear(rnn_input_size, 1024)
        self.layer_norm = nn.LayerNorm(1024)
        self.rnn = nn.LSTM(input_size=1024, hidden_size=512, num_layers=4,
                           batch_first=True, bidirectional=True, dropout=0.3)
        self.classifier = nn.Linear(512 * 2, num_classes)

    def forward(self, x):
        x = self.cnn(x)
        b, c, t, f = x.shape
        # Fold channels and feature bins into one vector per time step.
        x = x.permute(0, 2, 1, 3).contiguous().view(b, t, c * f)
        x = self.dense(x)
        x = self.layer_norm(x)
        x, _ = self.rnn(x)
        x = self.classifier(x)
        return x
def greedy_decoder(output, vocab):
    """CTC-style greedy decode of network logits.

    Collapses consecutive repeated indices and drops the blank (index 0).

    Args:
        output: logits tensor of shape (batch, time, num_classes); only the
            first batch element is decoded. A pre-squeezed (time, num_classes)
            tensor is also accepted.
        vocab: sequence of characters; position 0 is the blank token.

    Returns:
        The decoded string.
    """
    # Pick the best class per frame, then select batch element 0 explicitly.
    # The original used .squeeze(), which returned a bare int for a
    # single-frame output (making .tolist() non-iterable) and a nested
    # list for batch > 1.
    arg_maxes = torch.argmax(output, dim=-1)
    if arg_maxes.dim() > 1:
        arg_maxes = arg_maxes[0]
    indices = arg_maxes.tolist()
    id_to_char = {i: char for i, char in enumerate(vocab)}
    decoded_chars = []
    prev_index = -1
    for index in indices:
        # Emit only on index change, and never emit the blank.
        if index != prev_index and index != 0:
            decoded_chars.append(id_to_char.get(index, ""))
        prev_index = index
    return "".join(decoded_chars)
def record_audio(duration, fs, filename):
    """Capture `duration` seconds of mono audio and save it as 16-bit PCM WAV.

    Args:
        duration: recording length in seconds.
        fs: sample rate in Hz.
        filename: output WAV path.

    Backend errors are printed rather than raised, so the caller's
    interactive loop keeps running after a failed recording.
    """
    print(f"\nRECORDING... ({duration} s)")
    try:
        samples = sd.rec(int(duration * fs), samplerate=fs, channels=1, dtype='float32')
        sd.wait()
        print("Recording finished.")
        # Scale float32 [-1, 1] samples to int16 PCM for the WAV container.
        pcm16 = (samples * 32767).astype(np.int16)
        wav.write(filename, fs, pcm16)
    except Exception as e:
        print(f"Recording Error: {e}")
def predict(audio_path):
    """Transcribe an audio file with the trained model and print the result.

    Loads the checkpoint, normalizes the audio to 16 kHz mono, computes a
    log-mel spectrogram, and greedy-decodes the network output.

    Args:
        audio_path: path to the audio file to transcribe.

    Returns:
        The recognized text, or None when the checkpoint file is missing.
        (The original returned nothing, so callers could not reuse the
        transcription.)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(MODEL_PATH):
        print(f"ERROR: {MODEL_PATH} not found!")
        return None
    # NOTE(review): the checkpoint and MelSpectrogram are rebuilt on every
    # call; consider caching them if predict() runs in a tight loop.
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    saved_vocab_size = checkpoint['classifier.bias'].shape[0]
    # Pad the vocabulary with placeholders when the checkpoint was trained
    # with more classes than VOCAB_STR provides (replaces the append loop).
    vocab = list(VOCAB_STR)
    if len(vocab) < saved_vocab_size:
        vocab.extend("?" * (saved_vocab_size - len(vocab)))
    model = DeepSpeechModel(num_classes=saved_vocab_size).to(device)
    model.load_state_dict(checkpoint)
    model.eval()
    waveform, sr = torchaudio.load(audio_path)
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio to mono.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_mels=128, n_fft=1024, hop_length=256
    ).to(device)
    spec = mel_transform(waveform.to(device))
    spec = torch.log(spec + 1e-9)  # log compression; epsilon avoids log(0)
    if spec.dim() == 2:
        spec = spec.unsqueeze(0)
    # (batch, mels, time) -> (batch, 1, time, mels), as the CNN expects.
    spec = spec.unsqueeze(1).permute(0, 1, 3, 2)
    with torch.no_grad():
        output = model(spec)
    text = greedy_decoder(output, vocab)
    text = text.replace("_", " ")
    print("-" * 40)
    print(f"RECOGNIZED: {text}")
    print("-" * 40)
    return text
if __name__ == "__main__":
    # Interactive loop: record a clip to a scratch file, transcribe, repeat
    # until the user types "q".
    wav_path = "live_final.wav"
    while (cmd := input("Press Enter to record, q to exit: ")).lower() != "q":
        record_audio(DURATION, 16000, wav_path)
        predict(wav_path)