baharbhz committed on
Commit
55a8763
·
verified ·
1 Parent(s): df2bbbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -35
app.py CHANGED
@@ -21,6 +21,9 @@ import torch
21
  import librosa
22
  import torchaudio
23
  import numpy as np
 
 
 
24
 
25
 
26
  url = "https://huggingface.co/MahtaFetrat/tempmodel/resolve/main/checkpoint-15-1200.zip"
@@ -34,21 +37,47 @@ output_dir = "extracted_model"
34
  subprocess.run(["unzip", zip_file, "-d", output_dir], check=True)
35
 
36
 
37
- from transformers import Wav2Vec2CTCTokenizer
 
 
 
38
 
39
- tokenizer = Wav2Vec2CTCTokenizer("vocab.json", unk_token="<unk>", pad_token="<pad>", word_delimiter_token="|")
 
 
40
 
 
 
 
 
 
 
 
41
 
42
- from transformers import Wav2Vec2FeatureExtractor
 
43
 
44
- feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
 
45
 
 
 
 
 
46
 
47
- from transformers import Wav2Vec2Processor
 
 
48
 
49
- tuned_wav2vec_processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
50
- tuned_wav2vec_model = Wav2Vec2ForCTC.from_pretrained("extracted_model/checkpoint-15-1200")
 
 
 
 
51
 
 
 
52
 
53
  def tuned_wav2vec_speech_file_to_array_fn(path):
54
  speech_array, sampling_rate = torchaudio.load(path)
@@ -88,34 +117,8 @@ def preprocess_audio(audio_path):
88
 
89
 
90
  def speech_to_text(audio_path):
91
- # waveform = preprocess_audio(audio_path)
92
-
93
- # input_values = processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
94
- # with torch.no_grad():
95
- # logits = model(input_values).logits
96
- # predicted_ids = torch.argmax(logits, dim=-1)
97
- # transcription = processor.batch_decode(predicted_ids)[0]
98
- # return transcription
99
-
100
- speech = tuned_wav2vec_speech_file_to_array_fn(audio_path)
101
-
102
- features = tuned_wav2vec_processor(
103
- speech,
104
- sampling_rate=tuned_wav2vec_processor.feature_extractor.sampling_rate,
105
- return_tensors="pt",
106
- padding=True
107
- )
108
-
109
- input_values = features.input_values
110
- attention_mask = features.attention_mask
111
-
112
- with torch.no_grad():
113
- logits = tuned_wav2vec_model(input_values, attention_mask=attention_mask).logits
114
-
115
- pred_ids = torch.argmax(logits, dim=-1)
116
-
117
- predicted = tuned_wav2vec_processor.batch_decode(pred_ids)
118
- return predicted[0]
119
 
120
 
121
  def video_to_text(video_path):
 
21
  import librosa
22
  import torchaudio
23
  import numpy as np
24
+ import torch
25
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
26
+ import librosa
27
 
28
 
29
  url = "https://huggingface.co/MahtaFetrat/tempmodel/resolve/main/checkpoint-15-1200.zip"
 
37
  subprocess.run(["unzip", zip_file, "-d", output_dir], check=True)
38
 
39
 
40
# Function for inference from an audio file path
def infer_from_audio_file(audio_file_path, model, processor, device="cpu"):
    """Transcribe a single audio file with a Wav2Vec2 CTC model.

    Args:
        audio_file_path: Path to the audio file to transcribe.
        model: A Wav2Vec2ForCTC (or compatible) model.
        processor: A Wav2Vec2Processor providing feature extraction and decoding.
        device: Torch device string to run inference on (default "cpu").

    Returns:
        The decoded transcription string for the audio file.
    """
    # Load the audio resampled to the 16 kHz rate the feature extractor expects
    audio, sampling_rate = librosa.load(audio_file_path, sr=16000)

    # Process the audio using the feature extractor from the processor
    inputs = processor(audio, sampling_rate=sampling_rate).input_values[0]
    input_features = [{"input_values": inputs}]

    batch = processor.pad(
        input_features,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors="pt",
    )

    # BUG FIX: the model must live on the same device as the inputs.
    # Previously only the inputs were moved, so any device other than
    # "cpu" raised a device-mismatch error at the forward pass.
    model = model.to(device)
    input_values = batch.input_values.to(device)

    # The feature extractor is configured with return_attention_mask=True,
    # so forward the mask to the model when the processor produced one.
    attention_mask = getattr(batch, "attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # Ensure the model is in evaluation mode (disables dropout, etc.)
    model.eval()

    with torch.no_grad():
        # Make predictions
        outputs = model(input_values, attention_mask=attention_mask)
        logits = outputs.logits

    # Greedy CTC decoding: most likely token id at each frame
    pred_ids = torch.argmax(logits, dim=-1)
    pred_str = processor.batch_decode(pred_ids.cpu().numpy())

    return pred_str[0]  # Return the decoded transcription of the audio
73
+
74
+
75
# Assemble the processor from its two components: a CTC tokenizer over the
# local vocabulary file and a feature extractor matching the training setup.
tokenizer = Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
)
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Load the fine-tuned CTC model weights unpacked from the downloaded checkpoint.
latest_checkpoint = "extracted_model/checkpoint-15-1200"
model = Wav2Vec2ForCTC.from_pretrained(latest_checkpoint)
81
 
82
  def tuned_wav2vec_speech_file_to_array_fn(path):
83
  speech_array, sampling_rate = torchaudio.load(path)
 
117
 
118
 
119
def speech_to_text(audio_path):
    """Transcribe the audio file at *audio_path* using the module-level
    fine-tuned Wav2Vec2 model and processor."""
    return infer_from_audio_file(audio_path, model, processor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
  def video_to_text(video_path):