ArseniyPerchik committed
Commit d250443 · 1 Parent(s): 9bf5ef6
app.py CHANGED
@@ -7,6 +7,7 @@ import matplotlib.animation as animation
7
  import tempfile
8
  import torch
9
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
10
  import torchaudio
11
  import torchaudio.transforms as T
12
  from matplotlib.patches import Circle
@@ -20,22 +21,10 @@ from types import SimpleNamespace
20
  # ---------------------------- #
21
  # models
22
  # a model for the automatic-speech-recognition task
23
- # device = "cuda:0" if torch.cuda.is_available() else "cpu"
24
- # torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
25
- # model_id = "./models_for_proj/librispeech_asr_dummy"
26
- # model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
27
- # model.to(device)
28
- # processor = AutoProcessor.from_pretrained(model_id)
29
- # asr_pipe = pipeline(
30
- # "automatic-speech-recognition",
31
- # model=model,
32
- # tokenizer=processor.tokenizer,
33
- # feature_extractor=processor.feature_extractor,
34
- # max_new_tokens=128,
35
- # torch_dtype=torch_dtype,
36
- # device=device,
37
- # )
38
- asr_pipe_default = pipeline("automatic-speech-recognition")
39
 
40
 
41
  # env variables
@@ -62,10 +51,10 @@ r_coverage = 10
62
  # ---------------------------- #
63
  def create_standing_animation():
64
  path = [(agent_pos.x, agent_pos.y)]
65
- return create_animation(path, r_coverage)
66
 
67
 
68
- def create_animation(path, r_coverage):
69
  # path = [(i,i) for i in range(90)]
70
  # targets_x = [20, 80, 80, 20]
71
  # targets_y = [20, 20, 80, 80]
@@ -135,7 +124,7 @@ def move_agent(target_input: int):
135
  agent_pos.x = path[-1][0]
136
  agent_pos.y = path[-1][1]
137
  # create animation
138
- video_output = create_animation(path, r_coverage)
139
 
140
  # update status
141
  status = f'Went to target {target_input}.'
@@ -147,24 +136,33 @@ def load_image_on_start():
147
  return np.random.rand(700, 700)
148
 
149
  def get_text_request(audio_input):
 
150
  audio_input_sr, audio_input_np = audio_input
151
  audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
152
  target_sr = 16000
153
  resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
154
  resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
155
  resampled_audio_input_np = resampled_audio_input_t.numpy()
156
- # result = asr_pipe(resampled_audio_input_np)
157
- result = asr_pipe_default(resampled_audio_input_np)
158
- return result["text"]
159
 
160
  def get_target_from_request(request_text):
161
- if 'ONE' in request_text:
162
  return 1
163
- if 'TWO' in request_text:
164
  return 2
165
- if 'THREE' in request_text:
166
  return 3
167
- if 'FOUR' in request_text:
168
  return 4
169
  return 'NO TARGET FOUND'
170
 
@@ -190,6 +188,7 @@ def create_demo():
190
  - to insert a model that understands the desired goal, rather than a simple keyword-matching function that can produce false goals
191
  - to incorporate a longer chain of goals; for example, go there, pick up the package, then come back
192
  - to introduce additional learnt capabilities
 
193
  """)
194
 
195
  # EVENTS:
@@ -208,3 +207,21 @@ def create_demo():
208
  # ---------------------------- #
209
  demo = create_demo()
210
  demo.launch()
7
  import tempfile
8
  import torch
9
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
10
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
11
  import torchaudio
12
  import torchaudio.transforms as T
13
  from matplotlib.patches import Circle
 
21
  # ---------------------------- #
22
  # models
23
  # a model for the automatic-speech-recognition task
24
+ # asr_pipe_default = pipeline("automatic-speech-recognition")
25
+ save_dir = './models_for_proj/wav2vec2-base-960h'
26
+ model = Wav2Vec2ForCTC.from_pretrained(save_dir)
27
+ processor = Wav2Vec2Processor.from_pretrained(save_dir)
28
 
29
 
30
  # env variables
 
51
  # ---------------------------- #
52
  def create_standing_animation():
53
  path = [(agent_pos.x, agent_pos.y)]
54
+ return create_animation(path)
55
 
56
 
57
+ def create_animation(path):
58
  # path = [(i,i) for i in range(90)]
59
  # targets_x = [20, 80, 80, 20]
60
  # targets_y = [20, 20, 80, 80]
 
124
  agent_pos.x = path[-1][0]
125
  agent_pos.y = path[-1][1]
126
  # create animation
127
+ video_output = create_animation(path)
128
 
129
  # update status
130
  status = f'Went to target {target_input}.'
 
136
  return np.random.rand(700, 700)
137
 
138
  def get_text_request(audio_input):
139
+ # --------------------------------------------------------------------------- #
140
  audio_input_sr, audio_input_np = audio_input
141
  audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
142
  target_sr = 16000
143
  resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
144
  resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
145
  resampled_audio_input_np = resampled_audio_input_t.numpy()
146
+ # --------------------------------------------------------------------------- #
147
+ # result = asr_pipe_default(resampled_audio_input_np)
148
+ inputs = processor(resampled_audio_input_np, sampling_rate=16000, return_tensors="pt", padding=True)
149
+ # Inference
150
+ with torch.no_grad():
151
+ logits = model(**inputs).logits
152
+ # Decode
153
+ predicted_ids = torch.argmax(logits, dim=-1)
154
+ transcription = processor.decode(predicted_ids[0])
155
+ # print("Transcription:", transcription)
156
+ return transcription
157
 
158
  def get_target_from_request(request_text):
159
+ if any(item in request_text for item in ['ONE', 'FIRST']):
160
  return 1
161
+ if any(item in request_text for item in ['TWO', 'SECOND']):
162
  return 2
163
+ if any(item in request_text for item in ['THREE', 'THIRD']):
164
  return 3
165
+ if any(item in request_text for item in ['FOUR', 'FOURTH']):
166
  return 4
167
  return 'NO TARGET FOUND'
168
 
 
188
  - to insert a model that understands the desired goal, rather than a simple keyword-matching function that can produce false goals
189
  - to incorporate a longer chain of goals; for example, go there, pick up the package, then come back
190
  - to introduce additional learnt capabilities
191
+ - to build more complex environments where the movement is not so straightforward
192
  """)
193
 
194
  # EVENTS:
 
207
  # ---------------------------- #
208
  demo = create_demo()
209
  demo.launch()
210
+
211
+
212
+
213
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
214
+ # torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
215
+ # model_id = "./models_for_proj/librispeech_asr_dummy"
216
+ # model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
217
+ # model.to(device)
218
+ # processor = AutoProcessor.from_pretrained(model_id)
219
+ # asr_pipe = pipeline(
220
+ # "automatic-speech-recognition",
221
+ # model=model,
222
+ # tokenizer=processor.tokenizer,
223
+ # feature_extractor=processor.feature_extractor,
224
+ # max_new_tokens=128,
225
+ # torch_dtype=torch_dtype,
226
+ # device=device,
227
+ # )
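
For reference, the ASR flow this commit switches app.py to can be exercised on its own, outside Gradio. The sketch below is a minimal consolidation of the added lines above together with the "TO SAVE THE MODEL" / "TO REUSE IT" cells of draft_1.ipynb; it assumes the one-off save step has been run first so that ./models_for_proj/wav2vec2-base-960h exists, and that audio arrives as the (sample_rate, numpy_array) tuple the Gradio audio component passes to get_text_request. It is a sketch of the idea, not the app itself.

# One-off: download facebook/wav2vec2-base-960h and cache it locally, mirroring
# the "TO SAVE THE MODEL" cells in draft_1.ipynb (assumed to be run once, offline).
import torch
import torchaudio.transforms as T
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

save_dir = "./models_for_proj/wav2vec2-base-960h"
Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").save_pretrained(save_dir)
Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h").save_pretrained(save_dir)

# Load from the local directory, as app.py now does at import time.
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)

def transcribe(audio_input):
    # audio_input is a (sample_rate, np.ndarray) tuple, the format gr.Audio provides.
    audio_sr, audio_np = audio_input
    audio_t = torch.tensor(audio_np, dtype=torch.float32)
    # Wav2Vec2 expects 16 kHz input, so resample whatever rate was recorded.
    resampler = T.Resample(audio_sr, 16000, dtype=audio_t.dtype)
    resampled_np = resampler(audio_t).numpy()
    inputs = processor(resampled_np, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy CTC decoding: argmax over the vocabulary, then collapse repeats and blanks.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])

The transcription comes back upper-case (THIS IS A SIMPLE TEXT in the notebook run below), which is why get_target_from_request matches upper-case keywords such as 'ONE' and 'FIRST'.
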
draft_1.ipynb CHANGED
@@ -1,618 +1,229 @@
1
  {
2
  "cells": [
3
  {
4
- "cell_type": "code",
5
- "id": "bf22c176a849df32",
6
  "metadata": {
7
  "ExecuteTime": {
8
- "end_time": "2025-04-21T06:10:01.065321Z",
9
- "start_time": "2025-04-21T06:10:01.060267Z"
10
  }
11
  },
 
12
  "source": [
13
- "from transformers import pipeline\n",
14
- "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline, AutoFeatureExtractor\n",
15
  "import torchaudio\n",
16
- "import torchaudio.transforms as T"
17
  ],
 
18
  "outputs": [],
19
- "execution_count": 32
20
  },
21
  {
22
  "metadata": {
23
- "collapsed": true,
24
  "ExecuteTime": {
25
- "end_time": "2025-04-21T05:03:48.582040Z",
26
- "start_time": "2025-04-21T04:51:46.343821Z"
27
  }
28
  },
29
  "cell_type": "code",
 
 
30
  "outputs": [
31
- {
32
- "data": {
33
- "text/plain": [
34
- "model.safetensors: 0%| | 0.00/151M [00:00<?, ?B/s]"
35
- ],
36
- "application/vnd.jupyter.widget-view+json": {
37
- "version_major": 2,
38
- "version_minor": 0,
39
- "model_id": "51ffb4afb57446278c28d690aa1b22e4"
40
- }
41
- },
42
- "metadata": {},
43
- "output_type": "display_data"
44
- },
45
- {
46
- "data": {
47
- "text/plain": [
48
- "generation_config.json: 0%| | 0.00/3.75k [00:00<?, ?B/s]"
49
- ],
50
- "application/vnd.jupyter.widget-view+json": {
51
- "version_major": 2,
52
- "version_minor": 0,
53
- "model_id": "86143c7cd15341e39db1e81231d4fd7e"
54
- }
55
- },
56
- "metadata": {},
57
- "output_type": "display_data"
58
- },
59
- {
60
- "data": {
61
- "text/plain": [
62
- "tokenizer_config.json: 0%| | 0.00/283k [00:00<?, ?B/s]"
63
- ],
64
- "application/vnd.jupyter.widget-view+json": {
65
- "version_major": 2,
66
- "version_minor": 0,
67
- "model_id": "01d06664af1c4c169175cd38b00fa78e"
68
- }
69
- },
70
- "metadata": {},
71
- "output_type": "display_data"
72
- },
73
- {
74
- "data": {
75
- "text/plain": [
76
- "vocab.json: 0%| | 0.00/836k [00:00<?, ?B/s]"
77
- ],
78
- "application/vnd.jupyter.widget-view+json": {
79
- "version_major": 2,
80
- "version_minor": 0,
81
- "model_id": "8833df76fdf24e92bf51c748aa71bc48"
82
- }
83
- },
84
- "metadata": {},
85
- "output_type": "display_data"
86
- },
87
- {
88
- "data": {
89
- "text/plain": [
90
- "tokenizer.json: 0%| | 0.00/2.48M [00:00<?, ?B/s]"
91
- ],
92
- "application/vnd.jupyter.widget-view+json": {
93
- "version_major": 2,
94
- "version_minor": 0,
95
- "model_id": "3f5dfd342c574c2698f42c51a567a77e"
96
- }
97
- },
98
- "metadata": {},
99
- "output_type": "display_data"
100
- },
101
- {
102
- "data": {
103
- "text/plain": [
104
- "merges.txt: 0%| | 0.00/494k [00:00<?, ?B/s]"
105
- ],
106
- "application/vnd.jupyter.widget-view+json": {
107
- "version_major": 2,
108
- "version_minor": 0,
109
- "model_id": "e9363f15071d48878ef8230bd6c39177"
110
- }
111
- },
112
- "metadata": {},
113
- "output_type": "display_data"
114
- },
115
- {
116
- "data": {
117
- "text/plain": [
118
- "normalizer.json: 0%| | 0.00/52.7k [00:00<?, ?B/s]"
119
- ],
120
- "application/vnd.jupyter.widget-view+json": {
121
- "version_major": 2,
122
- "version_minor": 0,
123
- "model_id": "d0e777a74b9a47f3ad6a18254825122b"
124
- }
125
- },
126
- "metadata": {},
127
- "output_type": "display_data"
128
- },
129
- {
130
- "data": {
131
- "text/plain": [
132
- "added_tokens.json: 0%| | 0.00/34.6k [00:00<?, ?B/s]"
133
- ],
134
- "application/vnd.jupyter.widget-view+json": {
135
- "version_major": 2,
136
- "version_minor": 0,
137
- "model_id": "fd44e41876984d81a593d55e07e71be6"
138
- }
139
- },
140
- "metadata": {},
141
- "output_type": "display_data"
142
- },
143
- {
144
- "data": {
145
- "text/plain": [
146
- "special_tokens_map.json: 0%| | 0.00/2.19k [00:00<?, ?B/s]"
147
- ],
148
- "application/vnd.jupyter.widget-view+json": {
149
- "version_major": 2,
150
- "version_minor": 0,
151
- "model_id": "c3baafd24e1c4c4d9dcaf5a4715e846e"
152
- }
153
- },
154
- "metadata": {},
155
- "output_type": "display_data"
156
- },
157
- {
158
- "data": {
159
- "text/plain": [
160
- "preprocessor_config.json: 0%| | 0.00/185k [00:00<?, ?B/s]"
161
- ],
162
- "application/vnd.jupyter.widget-view+json": {
163
- "version_major": 2,
164
- "version_minor": 0,
165
- "model_id": "3a671889c0504b50bcff2aec93497d78"
166
- }
167
- },
168
- "metadata": {},
169
- "output_type": "display_data"
170
- },
171
  {
172
  "name": "stderr",
173
  "output_type": "stream",
174
  "text": [
175
  "Device set to use mps:0\n"
176
  ]
177
  }
178
  ],
179
- "execution_count": 4,
180
- "source": [
181
- "\n",
182
- "pipe = pipeline(model=\"openai/whisper-tiny\", task=\"automatic-speech-recognition\")\n"
183
- ],
184
- "id": "initial_id"
185
- },
186
- {
187
- "metadata": {},
188
- "cell_type": "code",
189
- "outputs": [],
190
- "execution_count": null,
191
- "source": [
192
- "# Load audio file\n",
193
- "waveform_1, sample_rate = torchaudio.load(\"sample.wav\")\n",
194
- "# Target sampling rate (e.g., 16000 Hz for Whisper)\n",
195
- "target_sr = 16000\n",
196
- "\n",
197
- "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
198
- "waveform = resampler(waveform_1)\n",
199
- "waveform_np = waveform.squeeze().numpy()\n",
200
- "\n",
201
- "print(waveform.shape) # (channels, samples) — usually (1, N)\n",
202
- "print(sample_rate)\n",
203
- "print(waveform_np)"
204
- ],
205
- "id": "dc202f529230fa87"
206
- },
207
- {
208
- "metadata": {
209
- "ExecuteTime": {
210
- "end_time": "2025-04-21T05:08:38.144954Z",
211
- "start_time": "2025-04-21T05:08:38.087644Z"
212
- }
213
- },
214
- "cell_type": "code",
215
- "source": [
216
- "save_dir = \"./models_for_proj/whisper-tiny\"\n",
217
- "device = 'cpu'\n",
218
- "pipe.generation_config.save_pretrained(save_dir)\n",
219
- "pipe.tokenizer.save_pretrained(save_dir)\n",
220
- "pipe.feature_extractor.save_pretrained(save_dir)\n"
221
- ],
222
- "id": "ed09605af0b78939",
223
- "outputs": [
224
- {
225
- "data": {
226
- "text/plain": [
227
- "['./models_for_proj/whisper-tiny/preprocessor_config.json']"
228
- ]
229
- },
230
- "execution_count": 6,
231
- "metadata": {},
232
- "output_type": "execute_result"
233
- }
234
- ],
235
- "execution_count": 6
236
- },
237
- {
238
- "metadata": {
239
- "ExecuteTime": {
240
- "end_time": "2025-04-21T05:35:59.540770Z",
241
- "start_time": "2025-04-21T05:35:59.476164Z"
242
- }
243
- },
244
- "cell_type": "code",
245
- "source": [
246
- "\n",
247
- "# model = AutoModelForSpeechSeq2Seq.from_pretrained(save_dir, device=device)\n",
248
- "# model.config.forced_decoder_ids = None\n",
249
- "# processor = AutoProcessor.from_pretrained(save_dir, device=device)\n",
250
- "# tokenizer = AutoTokenizer.from_pretrained(save_dir, device=device)\n",
251
- "# feature_extractor = AutoFeatureExtractor.from_pretrained(save_dir, device=device)\n",
252
- "# pipe = pipeline(\"automatic-speech-recognition\", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
253
- "# result = pipe(\"sample.wav\")\n",
254
- "# result[\"text\"]"
255
- ],
256
- "id": "1dcd38e5ca08781b",
257
- "outputs": [
258
- {
259
- "ename": "TypeError",
260
- "evalue": "WhisperForConditionalGeneration.__init__() got an unexpected keyword argument 'device'",
261
- "output_type": "error",
262
- "traceback": [
263
- "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
264
- "\u001B[31mTypeError\u001B[39m Traceback (most recent call last)",
265
- "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[30]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtransformers\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline, AutoFeatureExtractor\n\u001B[32m 2\u001B[39m device = \u001B[33m'\u001B[39m\u001B[33mcpu\u001B[39m\u001B[33m'\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m model = \u001B[43mAutoModelForSpeechSeq2Seq\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\u001B[43msave_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4\u001B[39m model.config.forced_decoder_ids = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 5\u001B[39m processor = AutoProcessor.from_pretrained(save_dir, device=device)\n",
266
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py:573\u001B[39m, in \u001B[36m_BaseAutoModelClass.from_pretrained\u001B[39m\u001B[34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001B[39m\n\u001B[32m 571\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m(config) \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mcls\u001B[39m._model_mapping.keys():\n\u001B[32m 572\u001B[39m model_class = _get_model_class(config, \u001B[38;5;28mcls\u001B[39m._model_mapping)\n\u001B[32m--> \u001B[39m\u001B[32m573\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mmodel_class\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 574\u001B[39m \u001B[43m \u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m=\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mhub_kwargs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\n\u001B[32m 575\u001B[39m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 576\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 577\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mUnrecognized configuration class \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mconfig.\u001B[34m__class__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m for this kind of AutoModel: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mcls\u001B[39m.\u001B[34m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m.\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m\"\u001B[39m\n\u001B[32m 578\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mModel type should be one of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[33m'\u001B[39m\u001B[33m, \u001B[39m\u001B[33m'\u001B[39m.join(c.\u001B[34m__name__\u001B[39m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mfor\u001B[39;00m\u001B[38;5;250m \u001B[39mc\u001B[38;5;250m \u001B[39m\u001B[38;5;129;01min\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28mcls\u001B[39m._model_mapping.keys())\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m.\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 579\u001B[39m )\n",
267
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:272\u001B[39m, in \u001B[36mrestore_default_torch_dtype.<locals>._wrapper\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 270\u001B[39m old_dtype = torch.get_default_dtype()\n\u001B[32m 271\u001B[39m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[32m--> \u001B[39m\u001B[32m272\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 273\u001B[39m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[32m 274\u001B[39m torch.set_default_dtype(old_dtype)\n",
268
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:4401\u001B[39m, in \u001B[36mPreTrainedModel.from_pretrained\u001B[39m\u001B[34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001B[39m\n\u001B[32m 4395\u001B[39m config = \u001B[38;5;28mcls\u001B[39m._autoset_attn_implementation(\n\u001B[32m 4396\u001B[39m config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map\n\u001B[32m 4397\u001B[39m )\n\u001B[32m 4399\u001B[39m \u001B[38;5;28;01mwith\u001B[39;00m ContextManagers(model_init_context):\n\u001B[32m 4400\u001B[39m \u001B[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m4401\u001B[39m model = \u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4403\u001B[39m \u001B[38;5;66;03m# Make sure to tie the weights correctly\u001B[39;00m\n\u001B[32m 4404\u001B[39m model.tie_weights()\n",
269
- "\u001B[31mTypeError\u001B[39m: WhisperForConditionalGeneration.__init__() got an unexpected keyword argument 'device'"
270
- ]
271
- }
272
- ],
273
- "execution_count": 30
274
  },
275
  {
276
  "metadata": {
277
  "ExecuteTime": {
278
- "end_time": "2025-04-21T06:13:00.420733Z",
279
- "start_time": "2025-04-21T06:13:00.033330Z"
280
  }
281
  },
282
  "cell_type": "code",
283
  "source": [
284
- "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
285
- "# load dummy dataset and read audio files\n",
286
  "\n",
287
- "# input\n",
288
  "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
289
  "target_sr = 16000\n",
290
  "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
291
  "waveform = resampler(waveform)\n",
292
  "waveform_np = waveform.squeeze().numpy()\n",
 
293
  "\n",
294
- "\n",
295
- "processor = WhisperProcessor.from_pretrained(save_dir)\n",
296
- "model = WhisperForConditionalGeneration.from_pretrained(save_dir)\n",
297
- "model.config.forced_decoder_ids = None\n",
298
- "\n",
299
- "input_features = processor(waveform_np, sampling_rate=target_sr, return_tensors=\"pt\", device=device).input_features\n",
300
- "\n",
301
- "# generate token ids\n",
302
- "predicted_ids = model.generate(input_features)\n",
303
- "# decode token ids to text\n",
304
- "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
305
- "# ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']\n",
306
- "print(transcription)\n",
307
- "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
308
- "# [' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']\n",
309
- "print(transcription)"
310
  ],
311
- "id": "b0865456fed26d31",
312
  "outputs": [
313
  {
314
- "name": "stderr",
315
  "output_type": "stream",
316
  "text": [
317
- "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
318
- ]
319
- },
320
- {
321
- "ename": "ValueError",
322
- "evalue": "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively.",
323
- "output_type": "error",
324
- "traceback": [
325
- "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
326
- "\u001B[31mValueError\u001B[39m Traceback (most recent call last)",
327
- "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[34]\u001B[39m\u001B[32m, line 19\u001B[39m\n\u001B[32m 16\u001B[39m input_features = processor(waveform_np, sampling_rate=target_sr, return_tensors=\u001B[33m\"\u001B[39m\u001B[33mpt\u001B[39m\u001B[33m\"\u001B[39m, device=device).input_features\n\u001B[32m 18\u001B[39m \u001B[38;5;66;03m# generate token ids\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m19\u001B[39m predicted_ids = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgenerate\u001B[49m\u001B[43m(\u001B[49m\u001B[43minput_features\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 20\u001B[39m \u001B[38;5;66;03m# decode token ids to text\u001B[39;00m\n\u001B[32m 21\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mFalse\u001B[39;00m)\n",
328
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:774\u001B[39m, in \u001B[36mWhisperGenerationMixin.generate\u001B[39m\u001B[34m(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, time_precision_features, return_token_timestamps, return_segments, return_dict_in_generate, force_unique_generate_call, **kwargs)\u001B[39m\n\u001B[32m 765\u001B[39m proc.set_begin_index(decoder_input_ids.shape[-\u001B[32m1\u001B[39m])\n\u001B[32m 767\u001B[39m \u001B[38;5;66;03m# 6.6 Run generate with fallback\u001B[39;00m\n\u001B[32m 768\u001B[39m (\n\u001B[32m 769\u001B[39m seek_sequences,\n\u001B[32m 770\u001B[39m seek_outputs,\n\u001B[32m 771\u001B[39m should_skip,\n\u001B[32m 772\u001B[39m do_condition_on_prev_tokens,\n\u001B[32m 773\u001B[39m model_output_type,\n\u001B[32m--> \u001B[39m\u001B[32m774\u001B[39m ) = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mgenerate_with_fallback\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 775\u001B[39m \u001B[43m \u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m=\u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 776\u001B[39m \u001B[43m \u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 777\u001B[39m \u001B[43m \u001B[49m\u001B[43mcur_bsz\u001B[49m\u001B[43m=\u001B[49m\u001B[43mcur_bsz\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 778\u001B[39m \u001B[43m \u001B[49m\u001B[43mbatch_idx_map\u001B[49m\u001B[43m=\u001B[49m\u001B[43mbatch_idx_map\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 779\u001B[39m \u001B[43m \u001B[49m\u001B[43mseek\u001B[49m\u001B[43m=\u001B[49m\u001B[43mseek\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 780\u001B[39m \u001B[43m \u001B[49m\u001B[43mnum_segment_frames\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnum_segment_frames\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 781\u001B[39m \u001B[43m \u001B[49m\u001B[43mmax_frames\u001B[49m\u001B[43m=\u001B[49m\u001B[43mmax_frames\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 782\u001B[39m \u001B[43m \u001B[49m\u001B[43mtemperatures\u001B[49m\u001B[43m=\u001B[49m\u001B[43mtemperatures\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 783\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 784\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 785\u001B[39m \u001B[43m \u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m=\u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 786\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 787\u001B[39m \u001B[43m \u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m=\u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 788\u001B[39m \u001B[43m 
\u001B[49m\u001B[43mreturn_token_timestamps\u001B[49m\u001B[43m=\u001B[49m\u001B[43mreturn_token_timestamps\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 789\u001B[39m \u001B[43m \u001B[49m\u001B[43mdo_condition_on_prev_tokens\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdo_condition_on_prev_tokens\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 790\u001B[39m \u001B[43m \u001B[49m\u001B[43mis_shortform\u001B[49m\u001B[43m=\u001B[49m\u001B[43mis_shortform\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 791\u001B[39m \u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m=\u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 792\u001B[39m \u001B[43m \u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 793\u001B[39m \u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m=\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 794\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 796\u001B[39m \u001B[38;5;66;03m# 6.7 In every generated sequence, split by timestamp tokens and extract segments\u001B[39;00m\n\u001B[32m 797\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m i, seek_sequence \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(seek_sequences):\n",
329
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:950\u001B[39m, in \u001B[36mWhisperGenerationMixin.generate_with_fallback\u001B[39m\u001B[34m(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, is_shortform, batch_size, attention_mask, kwargs)\u001B[39m\n\u001B[32m 945\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m generate_kwargs.get(\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 946\u001B[39m generate_kwargs[\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m] = F.pad(\n\u001B[32m 947\u001B[39m generate_kwargs[\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m], (\u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, batch_size - cur_bsz), value=\u001B[32m0\u001B[39m\n\u001B[32m 948\u001B[39m )\n\u001B[32m--> \u001B[39m\u001B[32m950\u001B[39m seek_outputs = \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgenerate\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 951\u001B[39m \u001B[43m \u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 952\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 953\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 954\u001B[39m \u001B[43m \u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m=\u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 955\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 956\u001B[39m \u001B[43m \u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m=\u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 957\u001B[39m \u001B[43m \u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 958\u001B[39m \u001B[43m \u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 959\u001B[39m \u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mgenerate_kwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 960\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 962\u001B[39m model_output_type = \u001B[38;5;28mtype\u001B[39m(seek_outputs)\n\u001B[32m 964\u001B[39m \u001B[38;5;66;03m# post-process sequence tokens and outputs to be in list form\u001B[39;00m\n",
330
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001B[39m, in \u001B[36mcontext_decorator.<locals>.decorate_context\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 113\u001B[39m \u001B[38;5;129m@functools\u001B[39m.wraps(func)\n\u001B[32m 114\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mdecorate_context\u001B[39m(*args, **kwargs):\n\u001B[32m 115\u001B[39m \u001B[38;5;28;01mwith\u001B[39;00m ctx_factory():\n\u001B[32m--> \u001B[39m\u001B[32m116\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
331
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2219\u001B[39m, in \u001B[36mGenerationMixin.generate\u001B[39m\u001B[34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)\u001B[39m\n\u001B[32m 2208\u001B[39m warnings.warn(\n\u001B[32m 2209\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mYou are calling .generate() with the `input_ids` being on a device type different\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 2210\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33m than your model\u001B[39m\u001B[33m'\u001B[39m\u001B[33ms device. `input_ids` is on \u001B[39m\u001B[38;5;132;01m{\u001B[39;00minput_ids.device.type\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m, whereas the model\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m (...)\u001B[39m\u001B[32m 2215\u001B[39m \u001B[38;5;167;01mUserWarning\u001B[39;00m,\n\u001B[32m 2216\u001B[39m )\n\u001B[32m 2218\u001B[39m \u001B[38;5;66;03m# 9. prepare logits processors and stopping criteria\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m2219\u001B[39m prepared_logits_processor = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_get_logits_processor\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 2220\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2221\u001B[39m \u001B[43m \u001B[49m\u001B[43minput_ids_seq_length\u001B[49m\u001B[43m=\u001B[49m\u001B[43minput_ids_length\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2222\u001B[39m \u001B[43m \u001B[49m\u001B[43mencoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43minputs_tensor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2223\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2224\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2225\u001B[39m \u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m=\u001B[49m\u001B[43minputs_tensor\u001B[49m\u001B[43m.\u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2226\u001B[39m \u001B[43m \u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m=\u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2227\u001B[39m \u001B[43m \u001B[49m\u001B[43mnegative_prompt_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnegative_prompt_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2228\u001B[39m \u001B[43m \u001B[49m\u001B[43mnegative_prompt_attention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnegative_prompt_attention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2229\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 2230\u001B[39m prepared_stopping_criteria = \u001B[38;5;28mself\u001B[39m._get_stopping_criteria(\n\u001B[32m 2231\u001B[39m generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs\n\u001B[32m 2232\u001B[39m )\n\u001B[32m 2234\u001B[39m \u001B[38;5;66;03m# Set model_kwargs `use_cache` so we can use it later in forward runs\u001B[39;00m\n",
332
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:1083\u001B[39m, in \u001B[36mGenerationMixin._get_logits_processor\u001B[39m\u001B[34m(self, generation_config, input_ids_seq_length, encoder_input_ids, prefix_allowed_tokens_fn, logits_processor, device, model_kwargs, negative_prompt_ids, negative_prompt_attention_mask)\u001B[39m\n\u001B[32m 1074\u001B[39m processors.append(\n\u001B[32m 1075\u001B[39m SuppressTokensAtBeginLogitsProcessor(\n\u001B[32m 1076\u001B[39m generation_config.begin_suppress_tokens,\n\u001B[32m (...)\u001B[39m\u001B[32m 1079\u001B[39m )\n\u001B[32m 1080\u001B[39m )\n\u001B[32m 1081\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m generation_config.forced_decoder_ids \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 1082\u001B[39m \u001B[38;5;66;03m# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m1083\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 1084\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mYou have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 1085\u001B[39m \u001B[33m\"\u001B[39m\u001B[33min favour of `input_ids` or `decoder_input_ids` respectively.\u001B[39m\u001B[33m\"\u001B[39m,\n\u001B[32m 1086\u001B[39m )\n\u001B[32m 1088\u001B[39m \u001B[38;5;66;03m# TODO (joao): find a strategy to specify the order of the processors\u001B[39;00m\n\u001B[32m 1089\u001B[39m processors = \u001B[38;5;28mself\u001B[39m._merge_criteria_processor_list(processors, logits_processor)\n",
333
- "\u001B[31mValueError\u001B[39m: You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively."
334
  ]
335
  }
336
  ],
337
- "execution_count": 34
338
  },
339
  {
340
- "metadata": {
341
- "ExecuteTime": {
342
- "end_time": "2025-04-21T06:15:41.079099Z",
343
- "start_time": "2025-04-21T06:15:37.277194Z"
344
- }
345
- },
346
- "cell_type": "code",
347
- "source": [
348
- "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
349
- "from datasets import load_dataset\n",
350
- "\n",
351
- "# load model and processor\n",
352
- "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
353
- "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
354
- "model.config.forced_decoder_ids = None\n",
355
- "\n",
356
- "# load dummy dataset and read audio files\n",
357
- "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
358
- "sample = ds[0][\"audio\"]\n",
359
- "input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features\n",
360
- "\n",
361
- "# generate token ids\n",
362
- "predicted_ids = model.generate(input_features)\n",
363
- "# decode token ids to text\n",
364
- "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
365
- "processor(transcription)\n",
366
- "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
367
- "processor(transcription)\n"
368
- ],
369
- "id": "b4137e08d1a516e5",
370
- "outputs": [
371
- {
372
- "name": "stderr",
373
- "output_type": "stream",
374
- "text": [
375
- "It is strongly recommended to pass the `sampling_rate` argument to `WhisperFeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.\n"
376
- ]
377
- },
378
- {
379
- "ename": "ValueError",
380
- "evalue": "could not convert string to float: ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'",
381
- "output_type": "error",
382
- "traceback": [
383
- "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
384
- "\u001B[31mValueError\u001B[39m Traceback (most recent call last)",
385
- "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[37]\u001B[39m\u001B[32m, line 18\u001B[39m\n\u001B[32m 16\u001B[39m \u001B[38;5;66;03m# decode token ids to text\u001B[39;00m\n\u001B[32m 17\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[32m---> \u001B[39m\u001B[32m18\u001B[39m \u001B[43mprocessor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtranscription\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 19\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[32m 20\u001B[39m processor(transcription)\n",
386
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/processing_whisper.py:69\u001B[39m, in \u001B[36mWhisperProcessor.__call__\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 66\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[33m\"\u001B[39m\u001B[33mYou need to specify either an `audio` or `text` input to process.\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m 68\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m audio \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m---> \u001B[39m\u001B[32m69\u001B[39m inputs = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mfeature_extractor\u001B[49m\u001B[43m(\u001B[49m\u001B[43maudio\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msampling_rate\u001B[49m\u001B[43m=\u001B[49m\u001B[43msampling_rate\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 70\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m text \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 71\u001B[39m encodings = \u001B[38;5;28mself\u001B[39m.tokenizer(text, **kwargs)\n",
387
- "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/feature_extraction_whisper.py:281\u001B[39m, in \u001B[36mWhisperFeatureExtractor.__call__\u001B[39m\u001B[34m(self, raw_speech, truncation, pad_to_multiple_of, return_tensors, return_attention_mask, padding, max_length, sampling_rate, do_normalize, device, return_token_timestamps, **kwargs)\u001B[39m\n\u001B[32m 279\u001B[39m raw_speech = [np.asarray([speech], dtype=np.float32).T \u001B[38;5;28;01mfor\u001B[39;00m speech \u001B[38;5;129;01min\u001B[39;00m raw_speech]\n\u001B[32m 280\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m is_batched \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(raw_speech, np.ndarray):\n\u001B[32m--> \u001B[39m\u001B[32m281\u001B[39m raw_speech = \u001B[43mnp\u001B[49m\u001B[43m.\u001B[49m\u001B[43masarray\u001B[49m\u001B[43m(\u001B[49m\u001B[43mraw_speech\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdtype\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnp\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfloat32\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 282\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(raw_speech, np.ndarray) \u001B[38;5;129;01mand\u001B[39;00m raw_speech.dtype \u001B[38;5;129;01mis\u001B[39;00m np.dtype(np.float64):\n\u001B[32m 283\u001B[39m raw_speech = raw_speech.astype(np.float32)\n",
388
- "\u001B[31mValueError\u001B[39m: could not convert string to float: ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'"
389
- ]
390
- }
391
- ],
392
- "execution_count": 37
393
  },
394
  {
395
  "metadata": {
396
  "ExecuteTime": {
397
- "end_time": "2025-04-21T05:31:26.352787Z",
398
- "start_time": "2025-04-21T05:31:26.343398Z"
399
  }
400
  },
401
  "cell_type": "code",
402
- "source": "",
403
- "id": "37fa63b1c22f4a69",
404
- "outputs": [
405
- {
406
- "name": "stdout",
407
- "output_type": "stream",
408
- "text": [
409
- "torch.Size([1, 24192])\n",
410
- "24000\n",
411
- "[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... -1.3932839e-05\n",
412
- " -3.6663318e-05 -1.3932839e-05]\n"
413
- ]
414
- }
415
- ],
416
- "execution_count": 25
417
  },
418
  {
419
  "metadata": {
420
  "ExecuteTime": {
421
- "end_time": "2025-04-21T06:28:40.294060Z",
422
- "start_time": "2025-04-21T06:28:35.493462Z"
423
  }
424
  },
425
  "cell_type": "code",
426
  "source": [
427
- "import torch\n",
428
- "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n",
429
- "from datasets import load_dataset\n",
430
- "\n",
431
- "\n",
432
- "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
433
- "torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
434
- "\n",
435
- "# model_id = \"distil-whisper/distil-small.en\"\n",
436
- "model_id = \"./models_for_proj/librispeech_asr_dummy\"\n",
437
- "\n",
438
- "model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
439
- " model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True\n",
440
- ")\n",
441
- "model.to(device)\n",
442
- "\n",
443
- "processor = AutoProcessor.from_pretrained(model_id)\n",
444
- "\n",
445
- "pipe = pipeline(\n",
446
- " \"automatic-speech-recognition\",\n",
447
- " model=model,\n",
448
- " tokenizer=processor.tokenizer,\n",
449
- " feature_extractor=processor.feature_extractor,\n",
450
- " max_new_tokens=128,\n",
451
- " torch_dtype=torch_dtype,\n",
452
- " device=device,\n",
453
- ")\n",
454
- "\n",
455
- "# dataset = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
456
- "# sample = dataset[0][\"audio\"]\n",
457
- "# result = pipe(sample)\n",
458
  "\n",
459
- "# input\n",
460
- "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
461
- "target_sr = 16000\n",
462
- "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
463
- "waveform = resampler(waveform)\n",
464
- "waveform_np = waveform.squeeze().numpy()\n",
465
- "# sample = dataset[2][\"audio\"]\n",
466
  "\n",
467
- "# result = pipe(sample)\n",
468
- "result = pipe(waveform_np)\n",
469
- "print(result[\"text\"])\n"
470
- ],
471
- "id": "e7f0a5bccb4e204f",
472
- "outputs": [
473
- {
474
- "name": "stderr",
475
- "output_type": "stream",
476
- "text": [
477
- "Device set to use cpu\n",
478
- "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: `max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.\n",
479
- " warnings.warn(\n",
480
- "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:573: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.\n",
481
- " warnings.warn(\n",
482
- "`generation_config` default values have been modified to match model-specific defaults: {'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}. If this is not desired, please set these values explicitly.\n",
483
- "A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.\n",
484
- "A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> to see related `.generate()` flags.\n"
485
- ]
486
- },
487
- {
488
- "name": "stdout",
489
- "output_type": "stream",
490
- "text": [
491
- " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his Gospel.\n"
492
- ]
493
- }
494
  ],
495
- "execution_count": 46
496
- },
497
- {
498
- "metadata": {
499
- "ExecuteTime": {
500
- "end_time": "2025-04-21T06:27:16.239153Z",
501
- "start_time": "2025-04-21T06:27:15.587609Z"
502
- }
503
- },
504
- "cell_type": "code",
505
- "source": [
506
- "save_dir = \"./models_for_proj/librispeech_asr_dummy\"\n",
507
- "pipe.model.save_pretrained(save_dir)\n",
508
- "pipe.tokenizer.save_pretrained(save_dir)\n",
509
- "pipe.feature_extractor.save_pretrained(save_dir)"
510
- ],
511
- "id": "81b57090829a7294",
512
  "outputs": [
513
  {
514
  "name": "stderr",
515
  "output_type": "stream",
516
  "text": [
517
- "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:3353: UserWarning: Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.\n",
518
- " warnings.warn(\n"
519
  ]
520
  },
521
  {
522
  "data": {
523
  "text/plain": [
524
- "['./models_for_proj/librispeech_asr_dummy/preprocessor_config.json']"
525
  ]
526
  },
527
- "execution_count": 45,
528
  "metadata": {},
529
  "output_type": "execute_result"
530
  }
531
  ],
532
- "execution_count": 45
533
  },
534
  {
535
- "metadata": {
536
- "ExecuteTime": {
537
- "end_time": "2025-04-21T05:31:45.237137Z",
538
- "start_time": "2025-04-21T05:31:45.234474Z"
539
- }
540
- },
541
- "cell_type": "code",
542
- "source": "target_sr",
543
- "id": "61b31c4b81fd098f",
544
- "outputs": [
545
- {
546
- "data": {
547
- "text/plain": [
548
- "16000"
549
- ]
550
- },
551
- "execution_count": 26,
552
- "metadata": {},
553
- "output_type": "execute_result"
554
- }
555
- ],
556
- "execution_count": 26
557
  },
558
  {
559
  "metadata": {
560
  "ExecuteTime": {
561
- "end_time": "2025-04-21T11:20:26.931270Z",
562
- "start_time": "2025-04-21T11:20:24.762498Z"
563
  }
564
  },
565
  "cell_type": "code",
566
  "source": [
567
- "# input\n",
 
 
568
  "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
569
  "target_sr = 16000\n",
570
  "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
571
  "waveform = resampler(waveform)\n",
572
- "waveform_np = waveform.squeeze().numpy()\n",
573
- "# sample = dataset[2][\"audio\"]\n",
574
- "\n",
575
- "# result = pipe(sample)\n",
576
- "result = pipe(waveform_np)\n",
577
- "print(result[\"text\"])"
578
  ],
579
- "id": "5c9f9ff839e346f8",
580
- "outputs": [
581
- {
582
- "name": "stdout",
583
- "output_type": "stream",
584
- "text": [
585
- " This is a simple text.\n"
586
- ]
587
- }
588
- ],
589
- "execution_count": 48
590
  },
591
  {
592
  "metadata": {
593
  "ExecuteTime": {
594
- "end_time": "2025-04-21T11:54:49.800197Z",
595
- "start_time": "2025-04-21T11:54:47.143900Z"
596
  }
597
  },
598
  "cell_type": "code",
599
  "source": [
600
- "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC\n",
601
- "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
602
- "model = Wav2Vec2ForCTC.from_pretrained(\"facebook/wav2vec2-base-960h\")"
603
  ],
604
- "id": "a7084d040f38e0f5",
605
  "outputs": [
606
  {
607
- "name": "stderr",
608
  "output_type": "stream",
609
  "text": [
610
- "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']\n",
611
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
612
  ]
613
  }
614
  ],
615
- "execution_count": 49
616
  },
617
  {
618
  "metadata": {},
@@ -620,7 +231,7 @@
620
  "outputs": [],
621
  "execution_count": null,
622
  "source": "",
623
- "id": "f886807e783c9532"
624
  }
625
  ],
626
  "metadata": {
 
1
  {
2
  "cells": [
3
  {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": "## FIRST CHECK",
7
+ "id": "518bcf10bfff3063"
8
+ },
9
+ {
10
  "metadata": {
11
  "ExecuteTime": {
12
+ "end_time": "2025-04-21T15:45:34.883735Z",
13
+ "start_time": "2025-04-21T15:45:33.734296Z"
14
  }
15
  },
16
+ "cell_type": "code",
17
  "source": [
18
+ "# gradio app.py --watch-dirs app.py\n",
19
+ "\n",
20
+ "import gradio as gr\n",
21
+ "import numpy as np\n",
22
+ "import matplotlib.pyplot as plt\n",
23
+ "import matplotlib.animation as animation\n",
24
+ "import tempfile\n",
25
+ "import torch\n",
26
+ "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n",
27
  "import torchaudio\n",
28
+ "import torchaudio.transforms as T\n",
29
+ "from matplotlib.patches import Circle\n",
30
+ "from stable_baselines3 import SAC\n",
31
+ "from warehouse_env import WarehouseEnv\n",
32
+ "from types import SimpleNamespace"
33
  ],
34
+ "id": "f861a8e81b92bceb",
35
  "outputs": [],
36
+ "execution_count": 50
37
  },
38
  {
39
  "metadata": {
 
40
  "ExecuteTime": {
41
+ "end_time": "2025-04-21T15:45:58.508916Z",
42
+ "start_time": "2025-04-21T15:45:53.686659Z"
43
  }
44
  },
45
  "cell_type": "code",
46
+ "source": "asr_pipe_default = pipeline(\"automatic-speech-recognition\")",
47
+ "id": "90ddbbf24fac7b1f",
48
  "outputs": [
49
  {
50
  "name": "stderr",
51
  "output_type": "stream",
52
  "text": [
53
+ "No model was supplied, defaulted to facebook/wav2vec2-base-960h and revision 22aad52 (https://huggingface.co/facebook/wav2vec2-base-960h).\n",
54
+ "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
55
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']\n",
56
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
57
  "Device set to use mps:0\n"
58
  ]
59
  }
60
  ],
61
+ "execution_count": 51
62
  },
63
  {
64
  "metadata": {
65
  "ExecuteTime": {
66
+ "end_time": "2025-04-21T15:46:03.873405Z",
67
+ "start_time": "2025-04-21T15:46:02.219145Z"
68
  }
69
  },
70
  "cell_type": "code",
71
  "source": [
72
  "\n",
73
  "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
74
  "target_sr = 16000\n",
75
  "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
76
  "waveform = resampler(waveform)\n",
77
  "waveform_np = waveform.squeeze().numpy()\n",
78
+ "# sample = dataset[2][\"audio\"]\n",
79
  "\n",
80
+ "# result = pipe(sample)\n",
81
+ "result = asr_pipe_default(waveform_np)\n",
82
+ "print(result[\"text\"])\n"
83
  ],
84
+ "id": "75dbfd85403eb511",
85
  "outputs": [
86
  {
87
+ "name": "stdout",
88
  "output_type": "stream",
89
  "text": [
90
+ "THIS IS A SIMPLE TEXT\n"
91
  ]
92
  }
93
  ],
94
+ "execution_count": 52
95
  },
96
  {
97
+ "metadata": {},
98
+ "cell_type": "markdown",
99
+ "source": "## TO SAVE THE MODEL",
100
+ "id": "e0a9c2fd7bce280a"
101
  },
102
  {
103
  "metadata": {
104
  "ExecuteTime": {
105
+ "end_time": "2025-04-21T15:51:20.114613Z",
106
+ "start_time": "2025-04-21T15:51:20.106995Z"
107
  }
108
  },
109
  "cell_type": "code",
110
+ "source": "save_dir = './models_for_proj/wav2vec2-base-960h'",
111
+ "id": "10f2808d5da846b9",
112
+ "outputs": [],
113
+ "execution_count": 53
114
  },
115
  {
116
  "metadata": {
117
  "ExecuteTime": {
118
+ "end_time": "2025-04-21T15:54:16.050333Z",
119
+ "start_time": "2025-04-21T15:54:12.432304Z"
120
  }
121
  },
122
  "cell_type": "code",
123
  "source": [
124
+ "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
125
  "\n",
126
+ "# Load pretrained model and processor\n",
127
+ "model = Wav2Vec2ForCTC.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
128
+ "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
129
  "\n",
130
+ "# Save locally\n",
131
+ "model.save_pretrained(save_dir)\n",
132
+ "processor.save_pretrained(save_dir)"
133
  ],
134
+ "id": "c22c64edf17613a0",
135
  "outputs": [
136
  {
137
  "name": "stderr",
138
  "output_type": "stream",
139
  "text": [
140
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']\n",
141
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
142
  ]
143
  },
144
  {
145
  "data": {
146
  "text/plain": [
147
+ "[]"
148
  ]
149
  },
150
+ "execution_count": 57,
151
  "metadata": {},
152
  "output_type": "execute_result"
153
  }
154
  ],
155
+ "execution_count": 57
156
  },
157
  {
158
+ "metadata": {},
159
+ "cell_type": "markdown",
160
+ "source": "## TO REUSE IT",
161
+ "id": "b2e0767904efbbb3"
162
  },
163
  {
164
  "metadata": {
165
  "ExecuteTime": {
166
+ "end_time": "2025-04-21T15:59:35.714597Z",
167
+ "start_time": "2025-04-21T15:59:35.705418Z"
168
  }
169
  },
170
  "cell_type": "code",
171
  "source": [
172
+ "import torchaudio\n",
173
+ "import torchaudio.transforms as T\n",
174
+ "\n",
175
  "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
176
  "target_sr = 16000\n",
177
  "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
178
  "waveform = resampler(waveform)\n",
179
+ "waveform_np = waveform.squeeze().numpy()"
180
  ],
181
+ "id": "394c5b342a6510",
182
+ "outputs": [],
183
+ "execution_count": 61
184
  },
185
  {
186
  "metadata": {
187
  "ExecuteTime": {
188
+ "end_time": "2025-04-21T15:59:36.498222Z",
189
+ "start_time": "2025-04-21T15:59:36.361763Z"
190
  }
191
  },
192
  "cell_type": "code",
193
  "source": [
194
+ "import torch\n",
195
+ "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
196
+ "\n",
197
+ "save_dir = './models_for_proj/wav2vec2-base-960h'\n",
198
+ "\n",
199
+ "# load\n",
200
+ "model = Wav2Vec2ForCTC.from_pretrained(save_dir)\n",
201
+ "processor = Wav2Vec2Processor.from_pretrained(save_dir)\n",
202
+ "\n",
203
+ "# Preprocess\n",
204
+ "inputs = processor(waveform_np, sampling_rate=16000, return_tensors=\"pt\", padding=True)\n",
205
+ "\n",
206
+ "# Inference\n",
207
+ "with torch.no_grad():\n",
208
+ " logits = model(**inputs).logits\n",
209
+ "\n",
210
+ "# Decode\n",
211
+ "predicted_ids = torch.argmax(logits, dim=-1)\n",
212
+ "transcription = processor.decode(predicted_ids[0])\n",
213
+ "\n",
214
+ "print(\"Transcription:\", transcription)\n"
215
  ],
216
+ "id": "af430cf9e1e42318",
217
  "outputs": [
218
  {
219
+ "name": "stdout",
220
  "output_type": "stream",
221
  "text": [
222
+ "Transcription: THIS IS A SIMPLE TEXT\n"
223
  ]
224
  }
225
  ],
226
+ "execution_count": 62
227
  },
228
  {
229
  "metadata": {},
 
231
  "outputs": [],
232
  "execution_count": null,
233
  "source": "",
234
+ "id": "113500626c003f89"
235
  }
236
  ],
237
  "metadata": {
models_for_proj/wav2vec2-base-960h/config.json ADDED
@@ -0,0 +1,109 @@
1
+ {
2
+ "activation_dropout": 0.1,
3
+ "adapter_attn_dim": null,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "group",
52
+ "feat_proj_dropout": 0.1,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.1,
55
+ "gradient_checkpointing": false,
56
+ "hidden_act": "gelu",
57
+ "hidden_dropout": 0.1,
58
+ "hidden_dropout_prob": 0.1,
59
+ "hidden_size": 768,
60
+ "initializer_range": 0.02,
61
+ "intermediate_size": 3072,
62
+ "layer_norm_eps": 1e-05,
63
+ "layerdrop": 0.1,
64
+ "mask_feature_length": 10,
65
+ "mask_feature_min_masks": 0,
66
+ "mask_feature_prob": 0.0,
67
+ "mask_time_length": 10,
68
+ "mask_time_min_masks": 2,
69
+ "mask_time_prob": 0.05,
70
+ "model_type": "wav2vec2",
71
+ "num_adapter_layers": 3,
72
+ "num_attention_heads": 12,
73
+ "num_codevector_groups": 2,
74
+ "num_codevectors_per_group": 320,
75
+ "num_conv_pos_embedding_groups": 16,
76
+ "num_conv_pos_embeddings": 128,
77
+ "num_feat_extract_layers": 7,
78
+ "num_hidden_layers": 12,
79
+ "num_negatives": 100,
80
+ "output_hidden_size": 768,
81
+ "pad_token_id": 0,
82
+ "proj_codevector_dim": 256,
83
+ "tdnn_dilation": [
84
+ 1,
85
+ 2,
86
+ 3,
87
+ 1,
88
+ 1
89
+ ],
90
+ "tdnn_dim": [
91
+ 512,
92
+ 512,
93
+ 512,
94
+ 512,
95
+ 1500
96
+ ],
97
+ "tdnn_kernel": [
98
+ 5,
99
+ 3,
100
+ 3,
101
+ 1,
102
+ 1
103
+ ],
104
+ "torch_dtype": "float32",
105
+ "transformers_version": "4.50.3",
106
+ "use_weighted_layer_sum": false,
107
+ "vocab_size": 32,
108
+ "xvector_output_dim": 512
109
+ }
models_for_proj/wav2vec2-base-960h/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75cf04071a643e1f23b8bb1571cde28cab80e3ff3a822ef0073d26f8fe98afdc
3
+ size 377611120
models_for_proj/wav2vec2-base-960h/preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "processor_class": "Wav2Vec2Processor",
8
+ "return_attention_mask": false,
9
+ "sampling_rate": 16000
10
+ }
models_for_proj/wav2vec2-base-960h/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
models_for_proj/wav2vec2-base-960h/tokenizer_config.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": true,
6
+ "normalized": false,
7
+ "rstrip": true,
8
+ "single_word": false,
9
+ "special": false
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": true,
14
+ "normalized": false,
15
+ "rstrip": true,
16
+ "single_word": false,
17
+ "special": false
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": true,
22
+ "normalized": false,
23
+ "rstrip": true,
24
+ "single_word": false,
25
+ "special": false
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": true,
30
+ "normalized": false,
31
+ "rstrip": true,
32
+ "single_word": false,
33
+ "special": false
34
+ }
35
+ },
36
+ "bos_token": "<s>",
37
+ "clean_up_tokenization_spaces": false,
38
+ "do_lower_case": false,
39
+ "do_normalize": true,
40
+ "eos_token": "</s>",
41
+ "extra_special_tokens": {},
42
+ "model_max_length": 1000000000000000019884624838656,
43
+ "pad_token": "<pad>",
44
+ "processor_class": "Wav2Vec2Processor",
45
+ "replace_word_delimiter_char": " ",
46
+ "return_attention_mask": false,
47
+ "target_lang": null,
48
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
49
+ "unk_token": "<unk>",
50
+ "word_delimiter_token": "|"
51
+ }
models_for_proj/wav2vec2-base-960h/vocab.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "'": 27,
3
+ "</s>": 2,
4
+ "<pad>": 0,
5
+ "<s>": 1,
6
+ "<unk>": 3,
7
+ "A": 7,
8
+ "B": 24,
9
+ "C": 19,
10
+ "D": 14,
11
+ "E": 5,
12
+ "F": 20,
13
+ "G": 21,
14
+ "H": 11,
15
+ "I": 10,
16
+ "J": 29,
17
+ "K": 26,
18
+ "L": 15,
19
+ "M": 17,
20
+ "N": 9,
21
+ "O": 8,
22
+ "P": 23,
23
+ "Q": 30,
24
+ "R": 13,
25
+ "S": 12,
26
+ "T": 6,
27
+ "U": 16,
28
+ "V": 25,
29
+ "W": 18,
30
+ "X": 28,
31
+ "Y": 22,
32
+ "Z": 31,
33
+ "|": 4
34
+ }
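
For reference, a minimal sketch (not part of this commit) of the greedy CTC decoding that processor.decode(predicted_ids[0]) performs with the vocabulary saved above; the helper greedy_ctc_decode and the reuse of save_dir here are illustrative assumptions, not code from the repository.

import json
import torch

save_dir = "./models_for_proj/wav2vec2-base-960h"

# token -> id mapping, exactly as stored in the vocab.json added above
with open(f"{save_dir}/vocab.json") as f:
    vocab = json.load(f)
id_to_token = {i: t for t, i in vocab.items()}

def greedy_ctc_decode(predicted_ids: torch.Tensor) -> str:
    # Hypothetical illustration of CTC greedy decoding:
    # collapse consecutive repeated ids, drop the <pad> (blank) token,
    # then map the "|" word-delimiter token back to spaces.
    tokens = []
    prev_id = None
    for idx in predicted_ids.tolist():
        if idx != prev_id and id_to_token[idx] != "<pad>":
            tokens.append(id_to_token[idx])
        prev_id = idx
    return "".join(tokens).replace("|", " ").strip()

# With the notebook's variables, greedy_ctc_decode(predicted_ids[0]) should
# agree with processor.decode(predicted_ids[0]) for simple inputs.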