creativepurus committed on
Commit
17f9f88
·
1 Parent(s): b804c93

updated Model Path

Browse files
Files changed (1) hide show
  1. app.py +53 -44
app.py CHANGED
@@ -1,64 +1,73 @@
1
- from transformers import Wav2Vec2Processor, Wav2Vec2Model
2
- from safetensors.torch import load_file
3
  import torch
4
- import gradio as gr
5
  import torchaudio
 
 
 
 
6
 
7
- # Load processor from Hugging Face Model Hub
8
- processor = Wav2Vec2Processor.from_pretrained("creativepurus/accent-wav2vec2")
 
9
 
10
- # Load base model (large version)
11
- base_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
 
 
12
 
13
- # Define your classifier head
14
- classifier_head = torch.nn.Sequential(
15
- torch.nn.AdaptiveAvgPool1d(1),
16
- torch.nn.Flatten(),
17
- torch.nn.Dropout(0.1),
18
- torch.nn.Linear(1024, 2) # 1024 hidden size for wav2vec2-large
19
- )
 
20
 
21
- # Load fine-tuned classifier weights
22
- state_dict = load_file("model.safetensors", device="cpu")
23
- classifier_head.load_state_dict(state_dict)
24
 
25
- # Combine base model + classifier head
26
- class AccentClassifier(torch.nn.Module):
27
- def __init__(self, base, head):
28
- super().__init__()
29
- self.base = base
30
- self.head = head
 
31
 
32
- def forward(self, input_values):
33
- with torch.no_grad():
34
- features = self.base(input_values).last_hidden_state
35
- logits = self.head(features.transpose(1, 2))
 
 
36
  return logits
37
 
38
- model = AccentClassifier(base_model, classifier_head)
 
 
39
  model.eval()
40
 
41
- # Inference function
42
- def predict(audio_path):
43
- # Load and preprocess audio
44
- waveform, sample_rate = torchaudio.load(audio_path)
45
  if sample_rate != 16000:
46
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
47
  waveform = resampler(waveform)
48
-
49
- inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
50
- logits = model(inputs.input_values)
51
- probs = torch.nn.functional.softmax(logits, dim=1)
52
-
53
- labels = ["Canadian English", "England English"]
54
- return {labels[i]: float(probs[0][i]) for i in range(2)}
55
 
56
- # Gradio Interface
57
  interface = gr.Interface(
58
- fn=predict,
59
- inputs=gr.Audio(source="upload", type="filepath"),
60
- outputs=gr.Label(num_top_classes=2),
61
- title="Accent Classification with Wav2Vec2-Large"
 
 
62
  )
63
 
64
  interface.launch()
 
 
 
1
  import torch
 
2
  import torchaudio
3
+ import gradio as gr
4
+ import os
5
+ import time
6
+ import numpy as np
7
 
8
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
9
+ from safetensors.torch import load_file
10
+ import torch.nn as nn
11
 
12
# Wait for model.safetensors to be available before loading weights.
model_path = "model.safetensors"


def _wait_for_file(path, timeout=300, poll_interval=5):
    """Block until *path* exists on disk.

    Args:
        path: Filesystem path to poll for.
        timeout: Maximum number of seconds to wait before giving up
            (default 300, i.e. up to 5 minutes).
        poll_interval: Seconds to sleep between existence checks (default 5).

    Raises:
        TimeoutError: If the file has not appeared within *timeout* seconds.
    """
    deadline = time.time() + timeout
    while not os.path.exists(path):
        if time.time() > deadline:
            raise TimeoutError(f"{path} not found after {timeout} seconds.")
        print(f"Waiting for {path} to be downloaded...")
        time.sleep(poll_interval)


_wait_for_file(model_path)
23
# Processor (feature extraction + tokenization) matching the fine-tuned
# checkpoint published alongside the classifier weights.
PROCESSOR_REPO = "creativepurus/accent-wav2vec2"
processor = Wav2Vec2Processor.from_pretrained(PROCESSOR_REPO)

# Fine-tuned weights, deserialized from safetensors onto the CPU.
state_dict = load_file(model_path, device="cpu")
 
28
 
29
# Same model architecture that was used during training.
class Wav2Vec2Classifier(nn.Module):
    """Binary accent classifier.

    A pretrained Wav2Vec2 encoder followed by mean pooling over time,
    dropout, and a linear head producing 2 class logits.

    NOTE: the attribute names (``wav2vec2``, ``dropout``, ``classifier``)
    must stay exactly as-is so the saved state_dict keys line up on load.
    """

    def __init__(self):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, 2)

    def forward(self, input_values, attention_mask=None):
        """Return raw (unnormalized) class logits of shape (batch, 2)."""
        encoded = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Mean-pool the per-frame hidden states into one vector per clip.
        clip_embedding = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(self.dropout(clip_embedding))


# Build the classifier, load the fine-tuned weights, and switch to
# inference mode (disables dropout).
model = Wav2Vec2Classifier()
model.load_state_dict(state_dict)
model.eval()
49
 
50
# Prediction function used as the Gradio callback.
def predict_accent(audio):
    """Classify the accent of an uploaded audio clip.

    Args:
        audio: Filesystem path to the audio file (Gradio ``type="filepath"``).

    Returns:
        str: Either "Canadian English" or "England English".
    """
    waveform, sample_rate = torchaudio.load(audio)
    # Wav2Vec2 expects 16 kHz input.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    # Bug fix: the previous ``.squeeze()`` left stereo clips as (2, T),
    # which the processor then treated as a batch of two separate inputs,
    # yielding (2, 2) logits and a meaningless flattened argmax. Downmix
    # all channels to mono instead; mono (1, T) input behaves as before.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    input_values = processor(
        waveform.numpy(), return_tensors="pt", sampling_rate=16000
    ).input_values
    with torch.no_grad():
        logits = model(input_values)
    predicted_class_id = logits.argmax().item()
    label_map = {0: "Canadian English", 1: "England English"}
    return label_map[predicted_class_id]
 
62
 
63
# Gradio front-end: upload (or record) an audio clip, show the predicted accent.
demo = gr.Interface(
    fn=predict_accent,
    inputs=gr.Audio(source="upload", type="filepath", label="Upload or Record Audio (WAV)"),
    outputs=gr.Textbox(label="Predicted Accent"),
    title="Accent Classification",
    description="This app classifies English accents as either Canadian or England using a fine-tuned Wav2Vec2 model.",
    allow_flagging="never",
)

demo.launch()