Commit ·
aac3fbb
1
Parent(s): 9e467b3
use decode to inspect ds
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -294,6 +294,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 294 |
processor: Any
|
| 295 |
decoder_start_token_id: int
|
| 296 |
task_id: int
|
|
|
|
|
|
|
| 297 |
|
| 298 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 299 |
# split inputs and labels since they have to be of different lengths and need
|
|
@@ -312,8 +314,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 312 |
|
| 313 |
# if bos token is appended in previous tokenization step,
|
| 314 |
# cut bos token here as it's appended later anyway
|
| 315 |
-
|
| 316 |
-
# labels = labels[:, 1:]
|
| 317 |
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
| 318 |
# # Replace language and task if they are in the beginning, otherwise add them
|
| 319 |
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
|
@@ -328,6 +329,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 328 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
| 329 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
batch["labels"] = labels
|
| 332 |
|
| 333 |
return batch
|
|
@@ -461,7 +471,7 @@ def load_maybe_streaming_dataset(
|
|
| 461 |
def print_data_samples(dataset, tokenizer, max_samples=5):
|
| 462 |
shown_samples = 0
|
| 463 |
for batch in dataset:
|
| 464 |
-
print("Target: ", tokenizer.
|
| 465 |
shown_samples += len(batch)
|
| 466 |
if shown_samples >= max_samples:
|
| 467 |
break
|
|
|
|
| 294 |
processor: Any
|
| 295 |
decoder_start_token_id: int
|
| 296 |
task_id: int
|
| 297 |
+
# TODO: remove - infer language from dataset
|
| 298 |
+
language_id: int = -100
|
| 299 |
|
| 300 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 301 |
# split inputs and labels since they have to be of different lengths and need
|
|
|
|
| 314 |
|
| 315 |
# if bos token is appended in previous tokenization step,
|
| 316 |
# cut bos token here as it's appended later anyway
|
| 317 |
+
|
|
|
|
| 318 |
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
| 319 |
# # Replace language and task if they are in the beginning, otherwise add them
|
| 320 |
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
|
|
|
| 329 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
| 330 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
| 331 |
|
| 332 |
+
# remove start of sentence token from labels
|
| 333 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
| 334 |
+
labels = labels[:, 1:]
|
| 335 |
+
|
| 336 |
+
# add start of sentence token to labels + language + task
|
| 337 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.task_id), labels), dim=1)
|
| 338 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.language_id), labels), dim=1)
|
| 339 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id), labels), dim=1)
|
| 340 |
+
|
| 341 |
batch["labels"] = labels
|
| 342 |
|
| 343 |
return batch
|
|
|
|
| 471 |
def print_data_samples(dataset, tokenizer, max_samples=5):
    """Print decoded target labels for the first few entries of *dataset*.

    Iterates the (possibly streaming) dataset, decoding each batch's
    ``labels`` with *tokenizer* and printing the result, until roughly
    *max_samples* samples have been shown.

    NOTE(review): ``len(batch)`` is used as the sample count — if each
    batch is a per-example dict this counts its keys, not examples;
    confirm against the dataset's element type.
    """
    remaining = max_samples
    for batch in dataset:
        print("Target: ", tokenizer.decode(batch["labels"]))
        remaining -= len(batch)
        if remaining <= 0:
            break
|