# STT-for-Turkish / test.py
# (Hugging Face file-page header: commit f490873 "Update test.py" by Aytacus, verified)
import torch
import torch.nn as nn
import torchaudio
import sounddevice as sd
import scipy.io.wavfile as wav
import numpy as np
import os
# Path to the trained model checkpoint (a state_dict loadable by load_state_dict).
MODEL_PATH = "model_best.pth"
# Seconds of audio captured per interactive recording.
DURATION = 5
# Character vocabulary: index 0 ("_") is the CTC blank; trailing space is a real class.
VOCAB_STR = "_abcçdefgğhıijklmnoöprsştuüvyzqwx "
class ResCNNBlock(nn.Module):
    """Residual CNN block: two 3x3 conv layers (BatchNorm + GELU each).

    The identity skip connection is applied only when the input and output
    channel counts agree; channel-changing blocks act as plain conv stacks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.GELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = nn.GELU()

    def forward(self, x):
        skip = x
        y = self.act1(self.bn1(self.conv1(x)))
        y = self.act2(self.bn2(self.conv2(y)))
        # Residual add only when channel counts match (e.g. skipped for 32->64).
        if skip.shape[1] == y.shape[1]:
            y = y + skip
        return y
class DeepSpeechModel(nn.Module):
    """CNN front-end + bidirectional LSTM acoustic model emitting per-frame logits.

    Expects input of shape (batch, 1, time, 128) — a log-mel spectrogram with
    128 mel bins — and returns (batch, time', num_classes) raw logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stride-2 stem halves both the time and frequency axes; the residual
        # blocks then refine features at 32 and 64 channels.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.GELU(),
            ResCNNBlock(32, 32),
            ResCNNBlock(32, 32),
            ResCNNBlock(32, 64),
            ResCNNBlock(64, 64),
            nn.Dropout(0.1),
        )
        # 64 channels * 64 frequency bins (128 mels halved by the stem).
        flattened_size = 64 * 64
        self.dense = nn.Linear(flattened_size, 1024)
        self.layer_norm = nn.LayerNorm(1024)
        self.rnn = nn.LSTM(input_size=1024, hidden_size=512, num_layers=4,
                           batch_first=True, bidirectional=True, dropout=0.3)
        # Bidirectional => forward and backward hidden states are concatenated.
        self.classifier = nn.Linear(512 * 2, num_classes)

    def forward(self, x):
        features = self.cnn(x)
        batch, channels, frames, freqs = features.shape
        # Fold channel and frequency axes into one feature vector per time frame.
        features = features.permute(0, 2, 1, 3).contiguous().view(batch, frames, channels * freqs)
        features = self.layer_norm(self.dense(features))
        features, _ = self.rnn(features)
        return self.classifier(features)
def greedy_decoder(output, vocab):
    """CTC-style greedy decoding of model logits.

    Takes the argmax class at each time step, collapses consecutive repeats,
    and drops the blank token (class index 0).

    Args:
        output: logits tensor of shape (batch, time, classes); only the first
            batch element is decoded.
        vocab: sequence of characters where position i maps class i.

    Returns:
        The decoded string.
    """
    # [0].reshape(-1) instead of .squeeze(): squeezing a length-1 time axis
    # yields a 0-d tensor whose .tolist() is a bare int, which then crashes
    # the loop below; indexing batch 0 also makes the batch handling explicit.
    arg_maxes = torch.argmax(output, dim=2)[0].reshape(-1).tolist()
    id_to_char = {i: char for i, char in enumerate(vocab)}
    decoded_chars = []
    prev_index = -1
    for index in arg_maxes:
        # Emit only on transitions (CTC collapse), never for the blank class.
        if index != prev_index and index != 0:
            decoded_chars.append(id_to_char.get(index, ""))
        prev_index = index
    return "".join(decoded_chars)
def record_audio(duration, fs, filename):
    """Capture `duration` seconds of mono audio at `fs` Hz into a 16-bit WAV.

    Any capture error (e.g. no input device) is printed instead of raised so
    the interactive loop keeps running.
    """
    print(f"\nRECORDING... ({duration} s)")
    try:
        samples = sd.rec(int(duration * fs), samplerate=fs, channels=1, dtype='float32')
        sd.wait()  # block until the capture buffer is full
        print("Recording finished.")
        # Scale float32 [-1, 1] samples to int16 PCM before writing.
        wav.write(filename, fs, (samples * 32767).astype(np.int16))
    except Exception as e:
        print(f"Recording Error: {e}")
def predict(audio_path):
    """Transcribe one audio file with the saved DeepSpeechModel checkpoint.

    Loads MODEL_PATH, converts `audio_path` to a 16 kHz mono log-mel
    spectrogram, runs greedy CTC decoding, and prints the recognized text.
    Returns None; errors are reported by printing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab = list(VOCAB_STR)
    if not os.path.exists(MODEL_PATH):
        print(f"ERROR: {MODEL_PATH} not found!")
        return
    # weights_only=True keeps torch.load from executing arbitrary pickled code.
    checkpoint = torch.load(MODEL_PATH, map_location=device,weights_only=True)
    # Infer the class count from the saved classifier layer; if the checkpoint
    # was trained with more classes than VOCAB_STR has, pad with placeholders.
    saved_vocab_size = checkpoint['classifier.bias'].shape[0]
    if len(vocab) != saved_vocab_size:
        while len(vocab) < saved_vocab_size:
            vocab.append("?")
    model = DeepSpeechModel(num_classes=saved_vocab_size).to(device)
    model.load_state_dict(checkpoint)
    model.eval()
    waveform, sr = torchaudio.load(audio_path)
    # The model was trained on 16 kHz mono audio: resample and downmix first.
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_mels=128, n_fft=1024, hop_length=256
    ).to(device)
    spec = mel_transform(waveform.to(device))
    spec = torch.log(spec + 1e-9)  # log-mel; epsilon guards log(0) on silence
    if spec.dim() == 2: spec = spec.unsqueeze(0)  # ensure (1, mels, time)
    spec = spec.unsqueeze(1)         # add channel axis -> (batch, 1, mels, time)
    spec = spec.permute(0, 1, 3, 2)  # -> (batch, 1, time, mels) as the CNN expects
    with torch.no_grad():
        output = model(spec)
    text = greedy_decoder(output, vocab)
    # "_" is the blank token; turn any stray blanks into spaces for display.
    text = text.replace("_", " ")
    print("-" * 40)
    print(f"RECOGNIZED: {text}")
    print("-" * 40)
if __name__ == "__main__":
    # Interactive loop: record a short clip to a scratch file, transcribe it,
    # and repeat until the user enters 'q' (case-insensitive).
    temp_file = "live_final.wav"
    while (choice := input("Press Enter to record, q to exit: ")).lower() != "q":
        record_audio(DURATION, 16000, temp_file)
        predict(temp_file)