mazesmazes committed on
Commit
8a4ea40
·
verified ·
1 Parent(s): 5d77fcc

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_config.py +8 -0
  2. asr_modeling.py +16 -13
asr_config.py CHANGED
@@ -63,6 +63,10 @@ class ASRConfig(transformers.PretrainedConfig):
63
  lora_dropout: float = 0.0,
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
 
 
 
 
66
  max_new_tokens: Optional[int] = None,
67
  min_new_tokens: Optional[int] = None,
68
  repetition_penalty: Optional[float] = None,
@@ -169,6 +173,10 @@ class ASRConfig(transformers.PretrainedConfig):
169
  else generation_defaults["no_repeat_ngram_size"]
170
  )
171
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
 
 
 
 
172
 
173
  if "audio_config" not in kwargs:
174
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
 
63
  lora_dropout: float = 0.0,
64
  lora_target_modules: Optional[list] = None, # Default: all linear layers
65
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
66
+ do_sample: bool = False,
67
+ temperature: Optional[float] = None,
68
+ top_p: Optional[float] = None,
69
+ top_k: Optional[int] = None,
70
  max_new_tokens: Optional[int] = None,
71
  min_new_tokens: Optional[int] = None,
72
  repetition_penalty: Optional[float] = None,
 
173
  else generation_defaults["no_repeat_ngram_size"]
174
  )
175
  self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
176
+ self.do_sample = do_sample
177
+ self.temperature = temperature
178
+ self.top_p = top_p
179
+ self.top_k = top_k
180
 
181
  if "audio_config" not in kwargs:
182
  self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
asr_modeling.py CHANGED
@@ -120,7 +120,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
120
  super().__init__(config)
121
 
122
  self.system_prompt = config.system_prompt
123
- self.enable_thinking = False # Can be enabled for experimental thinking mode
124
  target_dtype = getattr(torch, config.model_dtype)
125
 
126
  # Audio encoder (frozen)
@@ -137,11 +136,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
137
  self.generation_config.max_new_tokens = config.max_new_tokens
138
  self.generation_config.min_new_tokens = config.min_new_tokens
139
  self.generation_config.num_beams = config.num_beams
140
- self.generation_config.do_sample = False
141
- # Clear sampling params (inherited from LLM) since we use greedy decoding
142
- self.generation_config.temperature = None
143
- self.generation_config.top_p = None
144
- self.generation_config.top_k = None
145
  self.generation_config.use_cache = config.use_cache
146
  self.generation_config.length_penalty = config.length_penalty
147
  self.generation_config.repetition_penalty = config.repetition_penalty
@@ -554,7 +553,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
554
  tokenize=True,
555
  add_generation_prompt=True,
556
  return_tensors="pt",
557
- enable_thinking=self.enable_thinking,
558
  )
559
  input_ids = chat_result.input_ids.to(device)
560
 
@@ -574,17 +573,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
574
  )
575
 
576
  # Generate using language model
 
 
577
  output = self.language_model.generate(
 
578
  inputs_embeds=inputs_embeds,
579
  attention_mask=attention_mask,
580
  generation_config=self.generation_config,
581
  **generate_kwargs,
582
  )
583
 
584
- # When using inputs_embeds without input_ids, generate returns only new tokens
585
- if isinstance(output, torch.Tensor):
586
- return output
587
- return output.sequences
 
588
 
589
  def generate_streaming(
590
  self,
@@ -632,7 +635,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
632
  tokenize=True,
633
  add_generation_prompt=True,
634
  return_tensors="pt",
635
- enable_thinking=self.enable_thinking,
636
  )
637
  input_ids = chat_result.input_ids.to(device)
638
 
@@ -731,7 +734,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
731
  tokenize=True,
732
  add_generation_prompt=True,
733
  return_tensors="pt",
734
- enable_thinking=self.enable_thinking,
735
  ).to(device)
736
 
737
  if input_ids.dim() == 1:
 
120
  super().__init__(config)
121
 
122
  self.system_prompt = config.system_prompt
 
123
  target_dtype = getattr(torch, config.model_dtype)
124
 
125
  # Audio encoder (frozen)
 
136
  self.generation_config.max_new_tokens = config.max_new_tokens
137
  self.generation_config.min_new_tokens = config.min_new_tokens
138
  self.generation_config.num_beams = config.num_beams
139
+ self.generation_config.do_sample = config.do_sample
140
+ # Set sampling params from config (None means use model defaults)
141
+ self.generation_config.temperature = config.temperature
142
+ self.generation_config.top_p = config.top_p
143
+ self.generation_config.top_k = config.top_k
144
  self.generation_config.use_cache = config.use_cache
145
  self.generation_config.length_penalty = config.length_penalty
146
  self.generation_config.repetition_penalty = config.repetition_penalty
 
553
  tokenize=True,
554
  add_generation_prompt=True,
555
  return_tensors="pt",
556
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
557
  )
558
  input_ids = chat_result.input_ids.to(device)
559
 
 
573
  )
574
 
575
  # Generate using language model
576
+ # Pass both input_ids and inputs_embeds so repetition_penalty works correctly
577
+ # (it needs input_ids to track which tokens have been used)
578
  output = self.language_model.generate(
579
+ input_ids=input_ids,
580
  inputs_embeds=inputs_embeds,
581
  attention_mask=attention_mask,
582
  generation_config=self.generation_config,
583
  **generate_kwargs,
584
  )
585
 
586
+ # When using inputs_embeds with input_ids, generate returns full sequence
587
+ # Strip the input tokens to return only generated tokens
588
+ sequences = output if isinstance(output, torch.Tensor) else output.sequences
589
+ input_len = input_ids.shape[1]
590
+ return sequences[:, input_len:]
591
 
592
  def generate_streaming(
593
  self,
 
635
  tokenize=True,
636
  add_generation_prompt=True,
637
  return_tensors="pt",
638
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
639
  )
640
  input_ids = chat_result.input_ids.to(device)
641
 
 
734
  tokenize=True,
735
  add_generation_prompt=True,
736
  return_tensors="pt",
737
+ enable_thinking=False, # Disable Qwen3 thinking mode for ASR
738
  ).to(device)
739
 
740
  if input_ids.dim() == 1: