mazesmazes committed on
Commit
8f52147
·
verified ·
1 Parent(s): 8be1490

Training in progress - step 1000

Browse files
Files changed (1) hide show
  1. asr_modeling.py +3 -2
asr_modeling.py CHANGED
@@ -145,10 +145,12 @@ class ASRModel(PreTrainedModel, GenerationMixin):
145
  self.generation_config.length_penalty = config.length_penalty
146
  self.generation_config.repetition_penalty = config.repetition_penalty
147
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
148
- self.generation_config.eos_token_id = [
 
149
  self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
150
  self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
151
  ]
 
152
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
153
 
154
  # Feature extractor for audio preprocessing
@@ -233,7 +235,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
233
  decoder_kwargs = {
234
  "attn_implementation": config.attn_implementation,
235
  "trust_remote_code": True,
236
- "tie_word_embeddings": False,
237
  "low_cpu_mem_usage": True,
238
  "dtype": dtype,
239
  }
 
145
  self.generation_config.length_penalty = config.length_penalty
146
  self.generation_config.repetition_penalty = config.repetition_penalty
147
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
148
+ # Set EOS tokens, filtering out any that don't exist in the tokenizer
149
+ eos_candidates = [
150
  self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
151
  self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
152
  ]
153
+ self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
154
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
155
 
156
  # Feature extractor for audio preprocessing
 
235
  decoder_kwargs = {
236
  "attn_implementation": config.attn_implementation,
237
  "trust_remote_code": True,
 
238
  "low_cpu_mem_usage": True,
239
  "dtype": dtype,
240
  }