# dialect-demo / app.py — karenlu653 — commit 1d674fa ("fixing audio upload")
import torch
import torch.nn as nn
import librosa
import numpy as np
import json
from huggingface_hub import hf_hub_download
import gradio as gr
import soundfile as sf
from safetensors.torch import load_file
# ----------------- Model definition -----------------
class LanNetBinary(nn.Module):
    """Binary dialect classifier: GRU encoder plus a two-layer linear head.

    The hidden state at the final time step is projected through a
    192-unit bottleneck down to 2 output logits.
    """

    def __init__(self, input_dim=40, hidden_dim=512, num_layers=2):
        super().__init__()
        # NOTE: submodule names (gru / linear2 / linear3) are also the keys
        # in the pretrained checkpoint's state dict — do not rename them.
        self.gru = nn.GRU(input_dim, hidden_dim,
                          num_layers=num_layers, batch_first=True)
        self.linear2 = nn.Linear(hidden_dim, 192)
        self.linear3 = nn.Linear(192, 2)

    def forward(self, x):
        """Map a (batch, time, features) tensor to (batch, 2) logits."""
        sequence_out, _ = self.gru(x)
        final_step = sequence_out[:, -1, :]  # last time step of each sequence
        return self.linear3(self.linear2(final_step))
# ----------------- Load config + model -----------------
REPO_ID = "karenlu653/dialect_model_naive"


def _load_json(filename):
    """Download *filename* from the model repo and parse it as JSON."""
    # Context manager closes the handle promptly; the original
    # json.load(open(...)) pattern leaked a file handle per call.
    with open(hf_hub_download(REPO_ID, filename), "r") as f:
        return json.load(f)


config = _load_json("config.json")                # model hyperparameters
preproc = _load_json("preprocessor_config.json")  # feature-extraction settings
label_map = _load_json("label_mapping.json")      # class index (str) -> label

# Instantiate the model with the architecture stored alongside the checkpoint,
# falling back to the training defaults when a key is absent.
model = LanNetBinary(
    input_dim=config.get("input_dim", 40),
    hidden_dim=config.get("hidden_dim", 512),
    num_layers=config.get("num_layers", 2)
)

# Load pretrained weights and switch to inference mode.
state_dict = load_file(hf_hub_download(REPO_ID, "model.safetensors"))
model.load_state_dict(state_dict)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# ----------------- Feature extraction -----------------
def extract_features(y, sr):
    """Convert a waveform into a fixed-length log-mel feature tensor.

    Returns a float32 tensor shaped (1, max_len_frames, n_mels), zero-padded
    or truncated in time so every clip yields the same number of frames.
    """
    n_mels = preproc.get("n_mels", 40)
    n_fft = preproc.get("n_fft", 400)
    hop_length = preproc.get("hop_length", 160)
    max_len = preproc.get("max_len_frames", 200)
    target_sr = preproc.get("sampling_rate", 16000)

    # Bring the signal to the sampling rate the model was trained on.
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    # Log-mel spectrogram, transposed to time-major: (frames, n_mels).
    power_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels,
        n_fft=n_fft, hop_length=hop_length, power=2.0
    )
    fbanks = librosa.power_to_db(power_spec).T

    # Enforce a fixed frame count: pad short clips, truncate long ones.
    n_frames = fbanks.shape[0]
    if n_frames < max_len:
        fbanks = np.pad(fbanks, ((0, max_len - n_frames), (0, 0)),
                        mode="constant")
    else:
        fbanks = fbanks[:max_len, :]

    # Prepend a batch dimension: (1, T, F).
    return torch.tensor(fbanks, dtype=torch.float32).unsqueeze(0)
# ----------------- Prediction function -----------------
import os
import tempfile, shutil


def predict(audio_path):
    """Classify the audio file at *audio_path* and return the label string.

    Returns a human-readable message instead when no path is given or the
    file contains no samples.
    """
    if not audio_path:
        return "No audio provided"

    # Copy the upload to a private temp file so Gradio may reclaim its copy.
    # tempfile.mkstemp replaces the deprecated, race-prone tempfile.mktemp,
    # and the finally-block removes the file (the original leaked one temp
    # file per request).
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        shutil.copy(audio_path, tmp_path)
        y, sr = sf.read(tmp_path, dtype="float32")
    finally:
        os.remove(tmp_path)

    if len(y) == 0:
        return "No audio detected, please try again."

    # Downmix stereo/multichannel recordings to mono — the feature
    # extractor expects a 1-D waveform (sf.read returns (frames, channels)
    # for multichannel files, which would break the mel pipeline).
    if y.ndim > 1:
        y = y.mean(axis=1)

    feats = extract_features(y, sr).to(device)
    with torch.no_grad():
        logits = model(feats)
    pred = int(logits.argmax(dim=1))
    return label_map.get(str(pred), str(pred))
# ----------------- Gradio Interface -----------------
# Accept either a microphone recording or an uploaded file; the resulting
# filepath is handed to predict(), whose returned label is shown as text.
audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")

iface = gr.Interface(
    fn=predict,
    inputs=audio_input,
    outputs="text",
    title="Dialect Classification Demo",
    description="Upload or record audio to classify if this is the Shanghai dialect!",
)

# Launch the web app only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()