Training in progress - step 1000
Browse files- asr_modeling.py +3 -2
asr_modeling.py
CHANGED
|
@@ -145,10 +145,12 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 145 |
self.generation_config.length_penalty = config.length_penalty
|
| 146 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
| 147 |
self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
|
| 148 |
-
|
|
|
|
| 149 |
self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
|
| 150 |
self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
|
| 151 |
]
|
|
|
|
| 152 |
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
| 153 |
|
| 154 |
# Feature extractor for audio preprocessing
|
|
@@ -233,7 +235,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 233 |
decoder_kwargs = {
|
| 234 |
"attn_implementation": config.attn_implementation,
|
| 235 |
"trust_remote_code": True,
|
| 236 |
-
"tie_word_embeddings": False,
|
| 237 |
"low_cpu_mem_usage": True,
|
| 238 |
"dtype": dtype,
|
| 239 |
}
|
|
|
|
| 145 |
self.generation_config.length_penalty = config.length_penalty
|
| 146 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
| 147 |
self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
|
| 148 |
+
# Set EOS tokens, filtering out any that don't exist in the tokenizer
|
| 149 |
+
eos_candidates = [
|
| 150 |
self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
|
| 151 |
self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
|
| 152 |
]
|
| 153 |
+
self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
|
| 154 |
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
| 155 |
|
| 156 |
# Feature extractor for audio preprocessing
|
|
|
|
| 235 |
decoder_kwargs = {
|
| 236 |
"attn_implementation": config.attn_implementation,
|
| 237 |
"trust_remote_code": True,
|
|
|
|
| 238 |
"low_cpu_mem_usage": True,
|
| 239 |
"dtype": dtype,
|
| 240 |
}
|