tlemagueresse committed on
Commit
280e76e
·
1 Parent(s): 39fb66e

Define batch_loader as method

Browse files
Files changed (3) hide show
  1. example_usage_fastmodel_hf.py +4 -9
  2. fast_model.py +100 -103
  3. pipeline.pkl +2 -2
example_usage_fastmodel_hf.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torchaudio
2
  from datasets import load_dataset
3
  from sklearn.metrics import accuracy_score
@@ -6,17 +8,10 @@ from fast_model import FastModelHuggingFace
6
  repo_id = "tlmk22/QuefrencyGuardian"
7
  fast_model = FastModelHuggingFace.from_pretrained(repo_id)
8
 
9
- # Example: predicting on a single WAV file
10
- wav_path = "wave_example/chainsaw.wav"
11
- waveform, sampling_rate = torchaudio.load(wav_path) # Charger le fichier audio
12
- if sampling_rate != 12000:
13
- resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=12000)
14
- waveform = resampler(waveform)
15
-
16
  # Perform predictions for a single WAV file
17
  map_labels = {0: "chainsaw", 1: "environment"}
18
- wav_prediction = fast_model.predict(waveform)
19
- print(f"Prediction : {map_labels[wav_prediction]}")
20
 
21
  # Example: predicting on a Hugging Face dataset
22
  dataset = load_dataset("rfcx/frugalai")
 
1
+ from pathlib import Path
2
+
3
  import torchaudio
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
 
8
  repo_id = "tlmk22/QuefrencyGuardian"
9
  fast_model = FastModelHuggingFace.from_pretrained(repo_id)
10
 
 
 
 
 
 
 
 
11
  # Perform predictions for a single WAV file
12
  map_labels = {0: "chainsaw", 1: "environment"}
13
+ wav_prediction = fast_model.predict("wav_example/chainsaw.wav", device="cpu")
14
+ print(f"Prediction : {map_labels[wav_prediction[0]]}")
15
 
16
  # Example: predicting on a Hugging Face dataset
17
  dataset = load_dataset("rfcx/frugalai")
fast_model.py CHANGED
@@ -117,12 +117,9 @@ class FastModel:
117
  If the dataset is empty or invalid.
118
  """
119
  features, labels = [], []
120
- for audio, label in batch_audio_loader(
121
  dataset,
122
- waveform_duration=self.audio_processing_params["duration"],
123
  batch_size=batch_size,
124
- padding_method=self.audio_processing_params["padding_method"],
125
- device=self.device,
126
  ):
127
  feature = self.get_features(audio)
128
  features.append(feature)
@@ -157,12 +154,9 @@ class FastModel:
157
  if not self.model:
158
  raise NotFittedError("LGBM model is not fitted yet.")
159
  features = []
160
- for audio, _ in batch_audio_loader(
161
  dataset,
162
- waveform_duration=self.audio_processing_params["duration"],
163
  batch_size=batch_size,
164
- padding_method=self.audio_processing_params["padding_method"],
165
- device=self.device,
166
  ):
167
  feature = self.get_features(audio)
168
  features.append(feature)
@@ -207,115 +201,118 @@ class FastModel:
207
  dim=1,
208
  )
209
 
 
 
 
 
 
 
 
210
 
211
- def batch_audio_loader(
212
- dataset: Dataset,
213
- waveform_duration: int = 3,
214
- batch_size: int = 1,
215
- sr: int = 12000,
216
- device: Literal["cpu", "cuda"] = "cpu",
217
- padding_method: None | Literal["zero", "reflect", "replicate", "circular"] = None,
218
- offset: int = 0,
219
- ):
220
- """Optimized loader for audio data from a dataset for training or inference in batches.
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- Parameters
223
- ----------
224
- dataset : Dataset
225
- The dataset containing audio samples and labels.
226
- waveform_duration : int, optional
227
- Desired duration of the audio waveforms in seconds (default is 3).
228
- batch_size : int, optional
229
- Number of audio samples per batch (default is 1).
230
- sr : int, optional
231
- Target sampling rate for audio processing (default is 12000).
232
- device : str, optional
233
- Device for processing ("cpu" or "cuda") (default is "cpu").
234
- padding_method : str, optional
235
- Method to pad audio waveforms smaller than the desired size (e.g., "zero", "reflect").
236
- offset : int, optional
237
- Number of samples to skip before processing the first audio sample (default is 0).
238
-
239
- Yields
240
- ------
241
- tuple (Tensor, Tensor)
242
- A tuple (batch_audios, batch_labels), where:
243
- - batch_audios is a torch.tensor of processed audio waveforms.
244
- - batch_labels is a torch.tensor of corresponding audio labels.
245
 
246
- Raises
247
- ------
248
- ValueError
249
- If an unsupported sampling rate is encountered in the dataset.
250
- """
 
 
 
251
 
252
- def process_resampling(resample_buffer, resample_indices, batch_audios, sr, target_sr):
253
- if resample_buffer:
254
- resampler = torchaudio.transforms.Resample(
255
- orig_freq=sr, new_freq=target_sr, lowpass_filter_width=6
256
- )
257
- resampled = resampler(torch.stack(resample_buffer))
258
- for idx, original_idx in enumerate(resample_indices):
259
- batch_audios[original_idx] = resampled[idx]
260
-
261
- device = torch.device("cuda" if device == "cuda" and torch.cuda.is_available() else "cpu")
262
- batch_audios, batch_labels = [], []
263
- resample_24000, resample_24000_indices = [], []
264
-
265
- for i in range(len(dataset)):
266
- pa_subtable = query_table(dataset._data, i, indices=dataset._indices)
267
- wav_bytes = pa_subtable[0][0][0].as_py()
268
- sampling_rate = struct.unpack("<I", wav_bytes[24:28])[0]
269
-
270
- if sampling_rate not in [sr, sr * 2]:
271
- raise ValueError(
272
- f"Unsupported sampling rate: {sampling_rate}Hz. Only {sr}Hz and {sr * 2}Hz are allowed."
273
- )
274
 
275
- data_size = struct.unpack("<I", wav_bytes[40:44])[0] // 2
276
- if data_size == 0:
277
- batch_audios.append(torch.zeros(int(waveform_duration * SR)))
278
- else:
279
- try:
280
- waveform = (
281
- torch.frombuffer(wav_bytes[44:], dtype=torch.int16, offset=offset)[
282
- : int(waveform_duration * sampling_rate)
283
- ].float()
284
- / 32767
 
 
 
 
285
  )
286
- except Exception as e:
287
- continue # May append during fit for small audios. offset is set to 0 during predict.
288
- waveform = apply_padding(
289
- waveform, int(waveform_duration * sampling_rate), padding_method
290
- )
291
 
292
- if sampling_rate == sr:
293
- batch_audios.append(waveform)
294
- elif sampling_rate == 2 * sr:
295
- resample_24000.append(waveform)
296
- resample_24000_indices.append(len(batch_audios))
297
- batch_audios.append(None)
 
 
 
 
 
 
 
 
 
 
298
 
299
- batch_labels.append(pa_subtable[1][0].as_py())
 
 
 
 
 
300
 
301
- if len(batch_audios) == batch_size:
302
- # Perform resampling once and take advantage of Torch's vectorization capabilities.
303
- process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
304
 
305
- batch_audios_on_device = torch.stack(batch_audios).to(device)
306
- batch_labels_on_device = torch.tensor(batch_labels).to(device)
 
307
 
308
- yield batch_audios_on_device, batch_labels_on_device
 
309
 
310
- batch_audios, batch_labels = [], []
311
- resample_24000, resample_24000_indices = [], []
312
 
313
- if batch_audios:
314
- process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
315
- batch_audios_on_device = torch.stack(batch_audios).to(device)
316
- batch_labels_on_device = torch.tensor(batch_labels).to(device)
317
 
318
- yield batch_audios_on_device, batch_labels_on_device
 
 
 
 
 
319
 
320
 
321
  def apply_padding(
 
117
  If the dataset is empty or invalid.
118
  """
119
  features, labels = [], []
120
+ for audio, label in self.batch_audio_loader(
121
  dataset,
 
122
  batch_size=batch_size,
 
 
123
  ):
124
  feature = self.get_features(audio)
125
  features.append(feature)
 
154
  if not self.model:
155
  raise NotFittedError("LGBM model is not fitted yet.")
156
  features = []
157
+ for audio, _ in self.batch_audio_loader(
158
  dataset,
 
159
  batch_size=batch_size,
 
 
160
  ):
161
  feature = self.get_features(audio)
162
  features.append(feature)
 
201
  dim=1,
202
  )
203
 
204
+ def batch_audio_loader(
205
+ self,
206
+ dataset: Dataset,
207
+ batch_size: int = 1,
208
+ offset: int = 0,
209
+ ):
210
+ """Optimized loader for audio data from a dataset for training or inference in batches.
211
 
212
+ Parameters
213
+ ----------
214
+ dataset : Dataset
215
+ The dataset containing audio samples and labels.
216
+ waveform_duration : int, optional
217
+ Desired duration of the audio waveforms in seconds (default is 3).
218
+ batch_size : int, optional
219
+ Number of audio samples per batch (default is 1).
220
+ sr : int, optional
221
+ Target sampling rate for audio processing (default is 12000).
222
+ device : str, optional
223
+ Device for processing ("cpu" or "cuda") (default is "cpu").
224
+ padding_method : str, optional
225
+ Method to pad audio waveforms smaller than the desired size (e.g., "zero", "reflect").
226
+ offset : int, optional
227
+ Number of samples to skip before processing the first audio sample (default is 0).
228
+
229
+ Yields
230
+ ------
231
+ tuple (Tensor, Tensor)
232
+ A tuple (batch_audios, batch_labels), where:
233
+ - batch_audios is a torch.tensor of processed audio waveforms.
234
+ - batch_labels is a torch.tensor of corresponding audio labels.
235
 
236
+ Raises
237
+ ------
238
+ ValueError
239
+ If an unsupported sampling rate is encountered in the dataset.
240
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ def process_resampling(resample_buffer, resample_indices, batch_audios, sr, target_sr):
243
+ if resample_buffer:
244
+ resampler = torchaudio.transforms.Resample(
245
+ orig_freq=sr, new_freq=target_sr, lowpass_filter_width=6
246
+ )
247
+ resampled = resampler(torch.stack(resample_buffer))
248
+ for idx, original_idx in enumerate(resample_indices):
249
+ batch_audios[original_idx] = resampled[idx]
250
 
251
+ # For readability
252
+ sr = self.audio_processing_params["sample_rate"]
253
+ waveform_duration = self.audio_processing_params["duration"]
254
+ padding_method = self.audio_processing_params["padding_method"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ device = torch.device(
257
+ "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
258
+ )
259
+ batch_audios, batch_labels = [], []
260
+ resample_24000, resample_24000_indices = [], []
261
+
262
+ for i in range(len(dataset)):
263
+ pa_subtable = query_table(dataset._data, i, indices=dataset._indices)
264
+ wav_bytes = pa_subtable[0][0][0].as_py()
265
+ sampling_rate = struct.unpack("<I", wav_bytes[24:28])[0]
266
+
267
+ if sampling_rate not in [sr, sr * 2]:
268
+ raise ValueError(
269
+ f"Unsupported sampling rate: {sampling_rate}Hz. Only {sr}Hz and {sr * 2}Hz are allowed."
270
  )
 
 
 
 
 
271
 
272
+ data_size = struct.unpack("<I", wav_bytes[40:44])[0] // 2
273
+ if data_size == 0:
274
+ batch_audios.append(torch.zeros(int(waveform_duration * SR)))
275
+ else:
276
+ try:
277
+ waveform = (
278
+ torch.frombuffer(wav_bytes[44:], dtype=torch.int16, offset=offset)[
279
+ : int(waveform_duration * sampling_rate)
280
+ ].float()
281
+ / 32767
282
+ )
283
+ except Exception as e:
284
+ continue # May append during fit for small audios. offset is set to 0 during predict.
285
+ waveform = apply_padding(
286
+ waveform, int(waveform_duration * sampling_rate), padding_method
287
+ )
288
 
289
+ if sampling_rate == sr:
290
+ batch_audios.append(waveform)
291
+ elif sampling_rate == 2 * sr:
292
+ resample_24000.append(waveform)
293
+ resample_24000_indices.append(len(batch_audios))
294
+ batch_audios.append(None)
295
 
296
+ batch_labels.append(pa_subtable[1][0].as_py())
 
 
297
 
298
+ if len(batch_audios) == batch_size:
299
+ # Perform resampling once and take advantage of Torch's vectorization capabilities.
300
+ process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
301
 
302
+ batch_audios_on_device = torch.stack(batch_audios).to(device)
303
+ batch_labels_on_device = torch.tensor(batch_labels).to(device)
304
 
305
+ yield batch_audios_on_device, batch_labels_on_device
 
306
 
307
+ batch_audios, batch_labels = [], []
308
+ resample_24000, resample_24000_indices = [], []
 
 
309
 
310
+ if batch_audios:
311
+ process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
312
+ batch_audios_on_device = torch.stack(batch_audios).to(device)
313
+ batch_labels_on_device = torch.tensor(batch_labels).to(device)
314
+
315
+ yield batch_audios_on_device, batch_labels_on_device
316
 
317
 
318
  def apply_padding(
pipeline.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3243c0fd7f6cafa8492132711b0376da91838029cfe1362e2fc19ee6bf847894
3
- size 834063
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37040a799b897c6902c1b520cf902223f145c5d61831f0c316317a9d999d8d61
3
+ size 834075