creativepurus committed on
Commit
b804c93
·
1 Parent(s): 7785622

Updated Model Path

Browse files
Files changed (2) hide show
  1. app.py +50 -92
  2. requirements.txt +5 -14
app.py CHANGED
@@ -1,106 +1,64 @@
1
- # ------------------- Type "python app.py" in TERMINAL to Run the App -------------------
2
-
3
- import torch
4
- import torchaudio
5
- import gradio as gr
6
  from transformers import Wav2Vec2Processor, Wav2Vec2Model
7
  from safetensors.torch import load_file
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- # ------------------- Label Mapping -------------------
12
 
13
- id2label = {
14
- 0: "Canadian English",
15
- 1: "England English"
16
- }
17
 
18
- # ------------------- Load Processor -------------------
 
19
 
20
- processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")
 
 
 
 
 
 
21
 
22
- # ------------------- Define Model -------------------
 
 
23
 
24
- class Wav2Vec2Classifier(nn.Module):
25
- def __init__(self, num_labels):
26
- super(Wav2Vec2Classifier, self).__init__()
27
- self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
28
- self.dropout = nn.Dropout(0.2)
29
- self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, num_labels)
30
 
31
  def forward(self, input_values):
32
- outputs = self.wav2vec2(input_values)
33
- hidden_states = outputs.last_hidden_state
34
- pooled_output = hidden_states.mean(dim=1)
35
- logits = self.classifier(self.dropout(pooled_output))
36
  return logits
37
 
38
- # ------------------- Load Weights -------------------
39
-
40
- model = Wav2Vec2Classifier(num_labels=2)
41
- state_dict = load_file("model.safetensors", device="cpu") # assuming in root dir
42
- model.load_state_dict(state_dict)
43
  model.eval()
44
 
45
- # ------------------- Prediction Function -------------------
46
-
47
  def predict(audio_path):
48
- # Load & preprocess audio
49
- speech_array, sr = torchaudio.load(audio_path)
50
- if sr != 16000:
51
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
52
- speech_array = resampler(speech_array)
53
-
54
- inputs = processor(
55
- speech_array.squeeze().numpy(),
56
- sampling_rate=16000,
57
- return_tensors="pt",
58
- padding="max_length",
59
- truncation=True,
60
- max_length=16000 * 4
61
- )
62
-
63
- with torch.no_grad():
64
- logits = model(inputs.input_values)
65
- probs = torch.nn.functional.softmax(logits, dim=-1)
66
- pred_id = torch.argmax(probs, dim=-1).item()
67
-
68
- return id2label[pred_id]
69
-
70
- # ------------------- Gradio UI with Dark Theme -------------------
71
-
72
- with gr.Blocks(
73
- theme=gr.themes.Monochrome(primary_hue="blue", secondary_hue="purple", neutral_hue="slate"),
74
- css="""
75
- body { background-color: #1E1E2F !important; color: #E0E0E0 !important; }
76
- .gr-button { background-color: #3B82F6 !important; color: white !important; font-weight: bold; }
77
- .gr-textbox { font-size: 18px; }
78
- .gr-audio label { color: white !important; }
79
- """
80
- ) as demo:
81
- gr.Markdown(
82
- """
83
- <h1 style="text-align: center; color: #00FFFF;">🌍 Accent Classifier using Wav2Vec2</h1>
84
- <p style="text-align: center; font-size: 16px;">Upload or record a 4-second <b>English voice clip</b><br>
85
- This AI model detects whether your accent is <span style='color: #3B82F6; font-weight: bold;'>Canadian</span> or <span style='color: #FF4C4C; font-weight: bold;'>British</span>.</p>
86
- <br>
87
- """
88
- )
89
-
90
- with gr.Row():
91
- with gr.Column(scale=1):
92
- audio_input = gr.Audio(type="filepath", label="🎧 Upload or Record English Voice")
93
- submit_btn = gr.Button("πŸ” Detect Accent")
94
-
95
- with gr.Column(scale=1):
96
- label_output = gr.Text(label="πŸ—£οΈ Predicted Accent")
97
-
98
- submit_btn.click(fn=predict, inputs=audio_input, outputs=label_output)
99
-
100
- gr.Markdown("---")
101
- gr.Markdown(
102
- "<p style='text-align: center;'>πŸ‘¨β€πŸ’» Created by <a href='https://github.com/creativepurus' target='_blank' style='color:#66CFFF;'>Anand Purushottam</a> | <a href='https://www.linkedin.com/in/creativepurus/' target='_blank' style='color:#66CFFF;'>LinkedIn</a></p>"
103
- )
104
-
105
- if __name__ == "__main__":
106
- demo.launch()
 
 
 
 
 
 
1
  from transformers import Wav2Vec2Processor, Wav2Vec2Model
2
  from safetensors.torch import load_file
3
+ import torch
4
+ import gradio as gr
5
+ import torchaudio
 
6
 
7
# Load the processor (feature extractor + tokenizer) for the fine-tuned
# accent model from the Hugging Face Model Hub (cached after first download).
processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")
 
 
9
 
10
# Backbone: pretrained wav2vec2-large encoder (1024-dim hidden states).
# Used below as a frozen feature extractor — AccentClassifier.forward wraps
# it in torch.no_grad().
base_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
12
 
13
# Classification head: average-pool the encoder features over time, then
# project the 1024-dim wav2vec2-large representation down to 2 accent logits.
# Layer order/indices must stay fixed — the safetensors checkpoint is loaded
# into this Sequential by positional key ("0.weight", "3.bias", ...).
classifier_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool1d(output_size=1),
    torch.nn.Flatten(),
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(in_features=1024, out_features=2),
)
20
 
21
# Load fine-tuned classifier weights from the repo root.
# NOTE(review): load_state_dict is strict by default, so this assumes
# model.safetensors contains ONLY the Sequential head's parameters
# ("0.weight", "3.weight", "3.bias", ...). A full-model checkpoint would
# raise on unexpected keys — confirm what the training script saved.
state_dict = load_file("model.safetensors", device="cpu")
classifier_head.load_state_dict(state_dict)
24
 
25
# Glue module: frozen wav2vec2 encoder followed by the pooled classifier head.
class AccentClassifier(torch.nn.Module):
    """Wraps a feature-extractor backbone and a classifier head.

    The backbone is treated as frozen: the whole forward pass runs under
    ``torch.no_grad()``, so this module is inference-only as written.
    """

    def __init__(self, base, head):
        super().__init__()
        self.base = base
        self.head = head

    def forward(self, input_values):
        with torch.no_grad():
            # (batch, time, hidden) features from the encoder.
            hidden = self.base(input_values).last_hidden_state
            # Head expects channels-first (batch, hidden, time) for pooling.
            logits = self.head(hidden.transpose(1, 2))
        return logits
37
 
38
# Assemble the full pipeline and switch to inference mode
# (eval() disables the Dropout layer in the classifier head).
model = AccentClassifier(base_model, classifier_head)
model.eval()
40
 
41
# Inference function
def predict(audio_path):
    """Classify the accent spoken in an audio file.

    Args:
        audio_path: Path to an audio file readable by torchaudio.

    Returns:
        Dict mapping each accent label to its softmax probability.
    """
    # Load and resample to the 16 kHz rate wav2vec2 was trained on.
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # NOTE(review): .squeeze() assumes a mono clip; a stereo file keeps its
    # channel dimension and may confuse the processor — confirm upstream.
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000,
                       return_tensors="pt", padding=True)

    # Inference only: disable autograd here so the head/softmax don't build
    # a graph (the model's forward already guards the backbone internally).
    with torch.no_grad():
        logits = model(inputs.input_values)
        probs = torch.nn.functional.softmax(logits, dim=1)

    labels = ["Canadian English", "England English"]
    return {labels[i]: float(probs[0][i]) for i in range(2)}
55
+
56
# Gradio Interface: upload an audio clip, see both accent probabilities.
interface = gr.Interface(
    fn=predict,
    # Gradio 4+ renamed the `source=` keyword to `sources=[...]`; the old
    # spelling raises TypeError on current releases, and requirements.txt
    # leaves gradio unpinned (latest gets installed).
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=gr.Label(num_top_classes=2),
    title="Accent Classification with Wav2Vec2-Large",
)

interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,14 +1,5 @@
1
- fastapi==0.116.1
2
- gradio==5.38.2
3
- torch==2.5.1
4
- torchaudio==2.5.1
5
- transformers==4.41.2
6
- datasets==4.0.0
7
- huggingface-hub==0.34.1
8
- safetensors==0.5.3
9
- librosa==0.11.0
10
- soundfile==0.13.1
11
- pandas==2.3.1
12
- numpy==1.26.4
13
- scikit-learn==1.7.0
14
- uvicorn==0.35.0
 
1
+ torch
2
+ transformers
3
+ safetensors
4
+ torchaudio
5
+ gradio