Sanchit Gandhi commited on
Commit ·
569a4c9
1
Parent(s): 2c11eb6
Add scripts and weights
Browse files
run_speech_recognition_whisper.py
CHANGED
|
@@ -23,6 +23,7 @@ import os
|
|
| 23 |
import whisper
|
| 24 |
import sys
|
| 25 |
from dataclasses import dataclass, field
|
|
|
|
| 26 |
|
| 27 |
from typing import Optional, Dict, Union, List
|
| 28 |
|
|
@@ -275,7 +276,6 @@ class WhisperDataCollatorWithPadding:
|
|
| 275 |
"""
|
| 276 |
|
| 277 |
eos_token_id: int
|
| 278 |
-
time_stamp_token_id: int
|
| 279 |
|
| 280 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 281 |
"""
|
|
@@ -626,9 +626,7 @@ def main():
|
|
| 626 |
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 627 |
|
| 628 |
# Define data collator
|
| 629 |
-
|
| 630 |
-
t_stamp = tokenizer("<|notimestamps|>").input_ids[0]
|
| 631 |
-
whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=eos, time_stamp_token_id=t_stamp)
|
| 632 |
|
| 633 |
# make sure model uses 50257 as BOS
|
| 634 |
bos = tokenizer("<|startoftranscript|>").input_ids[0]
|
|
|
|
| 23 |
import whisper
|
| 24 |
import sys
|
| 25 |
from dataclasses import dataclass, field
|
| 26 |
+
import tempfile
|
| 27 |
|
| 28 |
from typing import Optional, Dict, Union, List
|
| 29 |
|
|
|
|
| 276 |
"""
|
| 277 |
|
| 278 |
eos_token_id: int
|
|
|
|
| 279 |
|
| 280 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 281 |
"""
|
|
|
|
| 626 |
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
| 627 |
|
| 628 |
# Define data collator
|
| 629 |
+
whisper_data_collator = WhisperDataCollatorWithPadding(eos_token_id=tokenizer.eos_token_id)
|
|
|
|
|
|
|
| 630 |
|
| 631 |
# make sure model uses 50257 as BOS
|
| 632 |
bos = tokenizer("<|startoftranscript|>").input_ids[0]
|