karenlu653 commited on
Commit
5edb4b5
·
1 Parent(s): 1d674fa

added tab for dataset

Browse files
Files changed (2) hide show
  1. app.py +49 -11
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import torch
2
- import torch.nn as nn
3
  import librosa
4
  import numpy as np
5
  import json
6
  from huggingface_hub import hf_hub_download
7
- import gradio as gr
8
- import soundfile as sf
9
  from safetensors.torch import load_file
 
 
 
 
 
10
 
11
  # ----------------- Model definition -----------------
12
  class LanNetBinary(nn.Module):
@@ -97,15 +100,50 @@ def predict(audio_path):
97
 
98
  return label_map.get(str(pred), str(pred))
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # ----------------- Gradio Interface -----------------
102
- iface = gr.Interface(
103
- fn=predict,
104
- inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
105
- outputs="text",
106
- title="Dialect Classification Demo",
107
- description="Upload or record audio to classify if this is the Shanghai dialect!"
108
- )
 
 
 
 
 
 
 
109
 
110
  if __name__ == "__main__":
111
- iface.launch()
 
1
  import torch
2
+ import gradio as gr
3
  import librosa
4
  import numpy as np
5
  import json
6
  from huggingface_hub import hf_hub_download
 
 
7
  from safetensors.torch import load_file
8
+ from datasets import load_dataset
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ from sklearn.metrics import confusion_matrix
12
+
13
 
14
  # ----------------- Model definition -----------------
15
  class LanNetBinary(nn.Module):
 
100
 
101
  return label_map.get(str(pred), str(pred))
102
 
103
+ def evaluate_dataset():
104
+ ds = load_dataset("karenlu653/dialect_model_demo", split="train")
105
+
106
+ y_true, y_pred = [], []
107
+ for row in ds:
108
+ y = np.array(row["audio"], dtype=np.float32)
109
+ sr = preproc.get("sampling_rate", 16000)
110
+ feats = extract_features(y, sr).to(device)
111
+ with torch.no_grad():
112
+ logits = model(feats)
113
+ pred = int(logits.argmax(dim=1))
114
+ y_pred.append(pred)
115
+ y_true.append(row["label"])
116
+
117
+ # Confusion matrix
118
+ labels = sorted(set(y_true))
119
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
120
+
121
+ # Plot
122
+ fig, ax = plt.subplots(figsize=(5, 4))
123
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=[label_map[str(l)] for l in labels],
124
+ yticklabels=[label_map[str(l)] for l in labels], ax=ax)
125
+ ax.set_xlabel("Predicted")
126
+ ax.set_ylabel("True")
127
+ ax.set_title("Confusion Matrix of Shanghai Demo Samples")
128
+ plt.tight_layout()
129
+
130
+ return fig
131
 
132
  # ----------------- Gradio Interface -----------------
133
+
134
+ with gr.Blocks() as demo:
135
+ with gr.Tab("Single Prediction"):
136
+ gr.Interface(
137
+ fn=predict,
138
+ inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
139
+ outputs="text",
140
+ description = "Upload or record audio to classify if this is the Shanghai dialect!",
141
+ live=False
142
+ )
143
+ with gr.Tab("Dataset Evaluation"):
144
+ eval_btn = gr.Button("Run Evaluation on Uploaded Dataset")
145
+ eval_output = gr.Plot()
146
+ eval_btn.click(evaluate_dataset, inputs=None, outputs=eval_output)
147
 
148
  if __name__ == "__main__":
149
+ demo.launch()
requirements.txt CHANGED
@@ -5,4 +5,7 @@ gradio
5
  safetensors
6
  huggingface_hub
7
  numpy
8
- soundfile
 
 
 
 
5
  safetensors
6
  huggingface_hub
7
  numpy
8
+ soundfile
9
+ scikit-learn
10
+ seaborn
11
+ matplotlib