# Hugging Face Space: dialect classification demo
import json

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# ----------------- Model definition -----------------
class LanNetBinary(nn.Module):
    """GRU-based binary classifier over per-frame acoustic features.

    Input:  a (batch, time, input_dim) feature tensor.
    Output: (batch, 2) unnormalized class logits.
    """

    def __init__(self, input_dim=40, hidden_dim=512, num_layers=2):
        super().__init__()
        # NOTE: attribute names (gru / linear2 / linear3) must stay fixed —
        # they are the keys of the pretrained safetensors state dict.
        self.gru = nn.GRU(input_dim, hidden_dim,
                          num_layers=num_layers, batch_first=True)
        self.linear2 = nn.Linear(hidden_dim, 192)
        self.linear3 = nn.Linear(192, 2)

    def forward(self, x):
        """Run the GRU and classify from the last time step's output."""
        seq_out, _ = self.gru(x)
        final_frame = seq_out[:, -1, :]
        return self.linear3(self.linear2(final_frame))
# ----------------- Load config + model -----------------
REPO_ID = "karenlu653/dialect_model_naive"


def _load_json(filename):
    """Download *filename* from the HF Hub repo and parse it as JSON.

    Uses a context manager so the file handle is closed promptly —
    the previous `json.load(open(...))` pattern leaked the handle.
    """
    with open(hf_hub_download(REPO_ID, filename), "r") as f:
        return json.load(f)


# Model hyperparameters, feature-extraction settings, and id -> label names.
config = _load_json("config.json")
preproc = _load_json("preprocessor_config.json")
label_map = _load_json("label_mapping.json")

# Instantiate model with correct params (defaults match the class signature).
model = LanNetBinary(
    input_dim=config.get("input_dim", 40),
    hidden_dim=config.get("hidden_dim", 512),
    num_layers=config.get("num_layers", 2),
)
state_dict = load_file(hf_hub_download(REPO_ID, "model.safetensors"))
model.load_state_dict(state_dict)
model.eval()  # inference only

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# ----------------- Feature extraction -----------------
def extract_features(y, sr):
    """Turn a raw waveform into a fixed-size log-mel feature tensor.

    Parameters
    ----------
    y : np.ndarray
        Audio samples — 1-D mono, or 2-D ``(frames, channels)`` as returned
        by ``soundfile.read``.  Multi-channel input is averaged down to mono.
    sr : int
        Sample rate of ``y``.

    Returns
    -------
    torch.Tensor of shape ``(1, max_len, n_mels)``, dtype float32.
    """
    n_mels = preproc.get("n_mels", 40)
    n_fft = preproc.get("n_fft", 400)
    hop_length = preproc.get("hop_length", 160)
    max_len = preproc.get("max_len_frames", 200)

    # soundfile returns (frames, channels) for multi-channel files, but the
    # librosa calls below expect a 1-D signal — downmix to mono first.
    if y.ndim > 1:
        y = np.mean(y, axis=1)

    # Resample if needed
    target_sr = preproc.get("sampling_rate", 16000)
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    # Log-mel spectrogram, transposed to (time, n_mels)
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=n_mels,
        n_fft=n_fft, hop_length=hop_length, power=2.0
    )
    fbanks = librosa.power_to_db(mel).T

    # Zero-pad or truncate to exactly max_len frames so the GRU input
    # always has the same time dimension.
    if fbanks.shape[0] < max_len:
        fbanks = np.pad(fbanks, ((0, max_len - fbanks.shape[0]), (0, 0)), mode="constant")
    else:
        fbanks = fbanks[:max_len, :]
    return torch.tensor(fbanks, dtype=torch.float32).unsqueeze(0)  # (1, T, F)
# ----------------- Prediction function -----------------
import tempfile, shutil, os


def predict(audio_path):
    """Classify the audio file at *audio_path* and return a label string.

    Returns a user-facing message when no audio is provided or the file
    is empty; otherwise the mapped label for the argmax class.
    """
    if not audio_path:
        return "No audio provided"
    import soundfile as sf

    # Copy to a private temp file before reading.  NamedTemporaryFile
    # replaces the insecure/deprecated tempfile.mktemp (race condition),
    # and try/finally removes the copy — previously it was leaked on
    # every call.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    try:
        shutil.copy(audio_path, tmp.name)
        y, sr = sf.read(tmp.name, dtype="float32")
    finally:
        os.unlink(tmp.name)

    if len(y) == 0:
        return "No audio detected, please try again."

    feats = extract_features(y, sr).to(device)
    with torch.no_grad():  # inference only — no autograd bookkeeping
        logits = model(feats)
    pred = int(logits.argmax(dim=1))
    # Fall back to the raw class index if the mapping file lacks the key.
    return label_map.get(str(pred), str(pred))
# ----------------- Gradio Interface -----------------
# Accept either a microphone recording or an uploaded file; pass it to
# predict() as a filesystem path.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")

iface = gr.Interface(
    fn=predict,
    inputs=_audio_input,
    outputs="text",
    title="Dialect Classification Demo",
    description="Upload or record audio to classify if this is the Shanghai dialect!",
)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()