creativepurus committed on
Commit
07f0183
·
1 Parent(s): 4930096

Added App.py

Browse files
Files changed (2) hide show
  1. app.py +106 -0
  2. requirements.txt +178 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------- Type "python app.py" in TERMINAL to Run the App -------------------
2
+
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
7
+ from safetensors.torch import load_file
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # ------------------- Label Mapping -------------------
12
+
13
# Accent names in class-index order; position i is the name for class id i.
_ACCENT_NAMES = ("Canadian English", "England English")

# Lookup from predicted class id to the accent name returned by the app.
id2label = dict(enumerate(_ACCENT_NAMES))
17
+
18
+ # ------------------- Load Processor -------------------
19
+
20
# Identifier handed to from_pretrained to fetch the audio preprocessing config.
_PROCESSOR_REPO = "creativepurus/accent-wav2vec2"

# Shared processor instance; predict() uses it to turn raw samples into model inputs.
processor = Wav2Vec2Processor.from_pretrained(_PROCESSOR_REPO)
21
+
22
+ # ------------------- Define Model -------------------
23
+
24
class Wav2Vec2Classifier(nn.Module):
    """Wav2Vec2 encoder with a mean-pooled linear classification head.

    Args:
        num_labels: Number of classes; sets the output size of the linear head.
    """

    def __init__(self, num_labels):
        super().__init__()
        # NOTE: the attribute names below define the state_dict keys, so they
        # must stay unchanged for previously saved weights to load.
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.dropout = nn.Dropout(0.2)
        encoder_width = self.wav2vec2.config.hidden_size
        self.classifier = nn.Linear(encoder_width, num_labels)

    def forward(self, input_values):
        """Return classification logits for ``input_values``.

        The encoder's ``last_hidden_state`` is averaged over dim 1 (presumably
        the time/frame axis -- confirm against the Wav2Vec2 docs), passed
        through dropout, and projected to ``num_labels`` scores.
        """
        encoded = self.wav2vec2(input_values)
        clip_embedding = encoded.last_hidden_state.mean(dim=1)
        regularized = self.dropout(clip_embedding)
        return self.classifier(regularized)
37
+
38
+ # ------------------- Load Weights -------------------
39
+
40
# Checkpoint with the fine-tuned weights; the relative path assumes the file
# sits in the directory the app is started from (the repo root).
_CHECKPOINT_FILE = "model.safetensors"

# Build the two-class network, then overwrite its parameters from the checkpoint.
model = Wav2Vec2Classifier(num_labels=2)
state_dict = load_file(_CHECKPOINT_FILE, device="cpu")
model.load_state_dict(state_dict)

# Switch to inference mode so the dropout layer is inactive during prediction.
model.eval()
44
+
45
+ # ------------------- Prediction Function -------------------
46
+
47
def predict(audio_path):
    """Classify the accent of the speaker in an audio file.

    Args:
        audio_path: Path to the audio clip, as supplied by the Gradio audio
            component (``None`` when nothing was uploaded or recorded).

    Returns:
        The accent name from ``id2label`` for the highest-scoring class.

    Raises:
        gr.Error: If no audio was provided or the clip contains no samples.
    """
    target_sr = 16000   # sampling rate the processor is called with
    clip_seconds = 4    # clips are padded / truncated to this duration

    # Gradio passes None when the button is clicked without any audio;
    # surface a readable message instead of crashing inside torchaudio.
    if not audio_path:
        raise gr.Error("Please upload or record an audio clip first.")

    # torchaudio.load returns a (channels, frames) tensor plus its sample rate.
    waveform, sr = torchaudio.load(audio_path)
    if waveform.numel() == 0:
        raise gr.Error("The audio clip is empty. Please provide a longer recording.")

    # Downmix to a single 1-D channel. Without this, a stereo clip reaches the
    # processor as a 2-D array, is treated as a batch of separate clips, and
    # the final .item() call fails on the resulting multi-element prediction.
    waveform = waveform.mean(dim=0)

    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)

    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_sr,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=target_sr * clip_seconds,
    )

    with torch.no_grad():
        logits = model(inputs.input_values)

    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class as the argmax of the probabilities.
    pred_id = torch.argmax(logits, dim=-1).item()

    return id2label[pred_id]
69
+
70
+ # ------------------- Gradio UI with Dark Theme -------------------
71
+
72
# Build the web UI. Components created inside the `with gr.Blocks(...)` context
# are attached to `demo` in the order they are instantiated, so statement order
# here determines the page layout.
with gr.Blocks(
    theme=gr.themes.Monochrome(primary_hue="blue", secondary_hue="purple", neutral_hue="slate"),
    # Extra CSS layered on top of the theme: dark page background, blue
    # buttons, larger textbox font, and white audio labels.
    css="""
    body { background-color: #1E1E2F !important; color: #E0E0E0 !important; }
    .gr-button { background-color: #3B82F6 !important; color: white !important; font-weight: bold; }
    .gr-textbox { font-size: 18px; }
    .gr-audio label { color: white !important; }
    """
) as demo:
    # Page header: title plus short usage instructions, written as raw HTML.
    gr.Markdown(
        """
        <h1 style="text-align: center; color: #00FFFF;">🌍 Accent Classifier using Wav2Vec2</h1>
        <p style="text-align: center; font-size: 16px;">Upload or record a 4-second <b>English voice clip</b><br>
        This AI model detects whether your accent is <span style='color: #3B82F6; font-weight: bold;'>Canadian</span> or <span style='color: #FF4C4C; font-weight: bold;'>British</span>.</p>
        <br>
        """
    )

    # Two equal-width columns: audio input and trigger button on the left,
    # prediction text on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" makes Gradio hand predict() a path string.
            audio_input = gr.Audio(type="filepath", label="🎧 Upload or Record English Voice")
            submit_btn = gr.Button("🔍 Detect Accent")

        with gr.Column(scale=1):
            label_output = gr.Text(label="🗣️ Predicted Accent")

    # Clicking the button runs predict() on the selected audio and writes the
    # returned accent name into the output textbox.
    submit_btn.click(fn=predict, inputs=audio_input, outputs=label_output)

    # Footer: horizontal rule followed by author credit links.
    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center;'>👨‍💻 Created by <a href='https://github.com/creativepurus' target='_blank' style='color:#66CFFF;'>Anand Purushottam</a> | <a href='https://www.linkedin.com/in/creativepurus/' target='_blank' style='color:#66CFFF;'>LinkedIn</a></p>"
    )

# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ accelerate==1.9.0
3
+ aiofiles==24.1.0
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.12.14
6
+ aiosignal==1.4.0
7
+ altair==5.5.0
8
+ annotated-types==0.7.0
9
+ anyio==4.9.0
10
+ asttokens==3.0.0
11
+ astunparse==1.6.3
12
+ attrs==24.3.0
13
+ audioread==3.0.1
14
+ blinker==1.9.0
15
+ Brotli==1.1.0
16
+ cachetools==6.1.0
17
+ certifi==2025.7.14
18
+ cffi==1.17.1
19
+ charset-normalizer==3.4.2
20
+ click==8.2.1
21
+ colorama==0.4.6
22
+ comm==0.2.2
23
+ contourpy==1.3.2
24
+ cycler==0.12.1
25
+ datasets==4.0.0
26
+ debugpy==1.8.14
27
+ decorator==5.2.1
28
+ dill==0.3.8
29
+ executing==2.2.0
30
+ fastapi==0.116.1
31
+ ffmpy==0.6.1
32
+ filelock==3.18.0
33
+ flatbuffers==25.2.10
34
+ fonttools==4.58.5
35
+ fpdf==1.7.2
36
+ frozenlist==1.7.0
37
+ fsspec==2025.3.0
38
+ gast==0.6.0
39
+ gitdb==4.0.12
40
+ GitPython==3.1.44
41
+ google-pasta==0.2.0
42
+ gradio==5.38.2
43
+ gradio_client==1.11.0
44
+ groovy==0.1.2
45
+ grpcio==1.73.1
46
+ h11==0.16.0
47
+ h5py==3.14.0
48
+ hf-xet==1.1.5
49
+ httpcore==1.0.9
50
+ httpx==0.28.1
51
+ huggingface-hub==0.34.1
52
+ idna==3.10
53
+ ipykernel==6.29.5
54
+ ipython==9.4.0
55
+ ipython_pygments_lexers==1.1.1
56
+ ipywidgets==8.1.7
57
+ jedi==0.19.2
58
+ Jinja2==3.1.6
59
+ joblib==1.5.1
60
+ jsonschema==4.24.0
61
+ jsonschema-specifications==2025.4.1
62
+ jupyter_client==8.6.3
63
+ jupyter_core==5.8.1
64
+ jupyterlab_widgets==3.0.15
65
+ keras==3.10.0
66
+ kiwisolver==1.4.8
67
+ lazy_loader==0.4
68
+ libclang==18.1.1
69
+ librosa==0.11.0
70
+ llvmlite==0.44.0
71
+ Markdown==3.8.2
72
+ markdown-it-py==3.0.0
73
+ MarkupSafe==3.0.2
74
+ matplotlib==3.10.3
75
+ matplotlib-inline==0.1.7
76
+ mdurl==0.1.2
77
+ ml_dtypes==0.5.1
78
+ mpmath==1.3.0
79
+ msgpack==1.1.1
80
+ multidict==6.6.3
81
+ multiprocess==0.70.16
82
+ namex==0.1.0
83
+ narwhals==1.47.0
84
+ nest-asyncio==1.6.0
85
+ networkx==3.5
86
+ numba==0.61.2
87
+ numpy==1.26.4
88
+ opt_einsum==3.4.0
89
+ optree==0.16.0
90
+ orjson==3.11.1
91
+ outcome==1.3.0.post0
92
+ packaging==25.0
93
+ pandas==2.3.1
94
+ parso==0.8.4
95
+ pillow==11.3.0
96
+ platformdirs==4.3.8
97
+ pooch==1.8.2
98
+ prompt_toolkit==3.0.51
99
+ propcache==0.3.2
100
+ protobuf==5.29.5
101
+ psutil==7.0.0
102
+ pure_eval==0.2.3
103
+ pyarrow==20.0.0
104
+ pycparser==2.22
105
+ pydantic==2.11.7
106
+ pydantic_core==2.33.2
107
+ pydeck==0.9.1
108
+ pydub==0.25.1
109
+ Pygments==2.19.2
110
+ pyparsing==3.2.3
111
+ PySocks==1.7.1
112
+ python-dateutil==2.9.0.post0
113
+ python-multipart==0.0.20
114
+ pytz==2025.2
115
+ pywin32==311
116
+ PyYAML==6.0.2
117
+ pyzmq==27.0.0
118
+ referencing==0.36.2
119
+ regex==2024.11.6
120
+ requests==2.32.4
121
+ rich==14.0.0
122
+ rpds-py==0.26.0
123
+ ruff==0.12.5
124
+ safehttpx==0.1.6
125
+ safetensors==0.5.3
126
+ scikit-learn==1.7.0
127
+ scipy==1.16.0
128
+ seaborn==0.13.2
129
+ selenium==4.27.1
130
+ semantic-version==2.10.0
131
+ setuptools==80.9.0
132
+ shellingham==1.5.4
133
+ six==1.17.0
134
+ smmap==5.0.2
135
+ sniffio==1.3.1
136
+ sortedcontainers==2.4.0
137
+ soundfile==0.13.1
138
+ soxr==0.5.0.post1
139
+ stack-data==0.6.3
140
+ starlette==0.47.2
141
+ streamlit==1.46.1
142
+ sympy==1.13.1
143
+ tenacity==9.1.2
144
+ tensorboard==2.19.0
145
+ tensorboard-data-server==0.7.2
146
+ tensorflow==2.19.0
147
+ termcolor==3.1.0
148
+ tf_keras==2.19.0
149
+ threadpoolctl==3.6.0
150
+ tokenizers==0.19.1
151
+ toml==0.10.2
152
+ tomlkit==0.13.3
153
+ torch==2.5.1+cu121
154
+ torchaudio==2.5.1+cu121
155
+ torchvision==0.20.1+cu121
156
+ tornado==6.5.1
157
+ tqdm==4.67.1
158
+ traitlets==5.14.3
159
+ transformers==4.41.2
160
+ trio==0.27.0
161
+ trio-websocket==0.11.1
162
+ typer==0.16.0
163
+ typing-inspection==0.4.1
164
+ typing_extensions==4.14.1
165
+ tzdata==2025.2
166
+ urllib3==2.5.0
167
+ uvicorn==0.35.0
168
+ watchdog==6.0.0
169
+ wcwidth==0.2.13
170
+ websocket-client==1.8.0
171
+ websockets==15.0.1
172
+ Werkzeug==3.1.3
173
+ wheel==0.45.1
174
+ widgetsnbextension==4.0.14
175
+ wrapt==1.17.2
176
+ wsproto==1.2.0
177
+ xxhash==3.5.0
178
+ yarl==1.20.1