mazesmazes committed on
Commit
cfdc978
·
verified ·
1 Parent(s): 2d5e50f

Training in progress - step 5000

Browse files
Files changed (4) hide show
  1. asr_modeling.py +8 -4
  2. config.json +1 -1
  3. generation_config.json +1 -1
  4. model.safetensors +1 -1
asr_modeling.py CHANGED
@@ -573,17 +573,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
573
  )
574
 
575
  # Generate using language model
 
 
576
  output = self.language_model.generate(
 
577
  inputs_embeds=inputs_embeds,
578
  attention_mask=attention_mask,
579
  generation_config=self.generation_config,
580
  **generate_kwargs,
581
  )
582
 
583
- # When using inputs_embeds without input_ids, generate returns only new tokens
584
- if isinstance(output, torch.Tensor):
585
- return output
586
- return output.sequences
 
587
 
588
  def generate_streaming(
589
  self,
 
573
  )
574
 
575
  # Generate using language model
576
+ # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
577
+ # (it needs input_ids to track which tokens have been used)
578
  output = self.language_model.generate(
579
+ input_ids=input_ids,
580
  inputs_embeds=inputs_embeds,
581
  attention_mask=attention_mask,
582
  generation_config=self.generation_config,
583
  **generate_kwargs,
584
  )
585
 
586
+ # When using inputs_embeds with input_ids, generate returns full sequence
587
+ # Strip the input tokens to return only generated tokens
588
+ sequences = output if isinstance(output, torch.Tensor) else output.sequences
589
+ input_len = input_ids.shape[1]
590
+ return sequences[:, input_len:]
591
 
592
  def generate_streaming(
593
  self,
config.json CHANGED
@@ -274,7 +274,7 @@
274
  "qformer_num_heads": 16,
275
  "qformer_num_layers": 2,
276
  "qformer_window_size": 15,
277
- "repetition_penalty": 1.0,
278
  "router_aux_loss_coef": 0.01,
279
  "system_prompt": "",
280
  "temperature": 0.7,
 
274
  "qformer_num_heads": 16,
275
  "qformer_num_layers": 2,
276
  "qformer_window_size": 15,
277
+ "repetition_penalty": 1.05,
278
  "router_aux_loss_coef": 0.01,
279
  "system_prompt": "",
280
  "temperature": 0.7,
generation_config.json CHANGED
@@ -11,7 +11,7 @@
11
  "no_repeat_ngram_size": 0,
12
  "num_beams": 1,
13
  "pad_token_id": 151643,
14
- "repetition_penalty": 1.0,
15
  "temperature": 0.7,
16
  "transformers_version": "5.0.0.dev0",
17
  "use_cache": true
 
11
  "no_repeat_ngram_size": 0,
12
  "num_beams": 1,
13
  "pad_token_id": 151643,
14
+ "repetition_penalty": 1.05,
15
  "temperature": 0.7,
16
  "transformers_version": "5.0.0.dev0",
17
  "use_cache": true
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b8d27e4a7c907ced9bf93828c060f517ab52e14aa1d2d507a1ce23f8ae3f9435
3
  size 58732960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1c29578f6e4473b5f6a25ba03515832cfc1c5698f516d02f7758722d09b7065
3
  size 58732960