Update custom model files, README, and requirements
Browse files
- asr_config.py +8 -0
- asr_modeling.py +16 -13
asr_config.py
CHANGED
|
@@ -63,6 +63,10 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 63 |
lora_dropout: float = 0.0,
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
max_new_tokens: Optional[int] = None,
|
| 67 |
min_new_tokens: Optional[int] = None,
|
| 68 |
repetition_penalty: Optional[float] = None,
|
|
@@ -169,6 +173,10 @@ class ASRConfig(transformers.PretrainedConfig):
|
|
| 169 |
else generation_defaults["no_repeat_ngram_size"]
|
| 170 |
)
|
| 171 |
self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
if "audio_config" not in kwargs:
|
| 174 |
self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
|
|
|
|
| 63 |
lora_dropout: float = 0.0,
|
| 64 |
lora_target_modules: Optional[list] = None, # Default: all linear layers
|
| 65 |
freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
|
| 66 |
+
do_sample: bool = False,
|
| 67 |
+
temperature: Optional[float] = None,
|
| 68 |
+
top_p: Optional[float] = None,
|
| 69 |
+
top_k: Optional[int] = None,
|
| 70 |
max_new_tokens: Optional[int] = None,
|
| 71 |
min_new_tokens: Optional[int] = None,
|
| 72 |
repetition_penalty: Optional[float] = None,
|
|
|
|
| 173 |
else generation_defaults["no_repeat_ngram_size"]
|
| 174 |
)
|
| 175 |
self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
|
| 176 |
+
self.do_sample = do_sample
|
| 177 |
+
self.temperature = temperature
|
| 178 |
+
self.top_p = top_p
|
| 179 |
+
self.top_k = top_k
|
| 180 |
|
| 181 |
if "audio_config" not in kwargs:
|
| 182 |
self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
|
asr_modeling.py
CHANGED
|
@@ -120,7 +120,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 120 |
super().__init__(config)
|
| 121 |
|
| 122 |
self.system_prompt = config.system_prompt
|
| 123 |
-
self.enable_thinking = False # Can be enabled for experimental thinking mode
|
| 124 |
target_dtype = getattr(torch, config.model_dtype)
|
| 125 |
|
| 126 |
# Audio encoder (frozen)
|
|
@@ -137,11 +136,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 137 |
self.generation_config.max_new_tokens = config.max_new_tokens
|
| 138 |
self.generation_config.min_new_tokens = config.min_new_tokens
|
| 139 |
self.generation_config.num_beams = config.num_beams
|
| 140 |
-
self.generation_config.do_sample =
|
| 141 |
-
#
|
| 142 |
-
self.generation_config.temperature =
|
| 143 |
-
self.generation_config.top_p =
|
| 144 |
-
self.generation_config.top_k =
|
| 145 |
self.generation_config.use_cache = config.use_cache
|
| 146 |
self.generation_config.length_penalty = config.length_penalty
|
| 147 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
|
@@ -554,7 +553,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 554 |
tokenize=True,
|
| 555 |
add_generation_prompt=True,
|
| 556 |
return_tensors="pt",
|
| 557 |
-
enable_thinking=
|
| 558 |
)
|
| 559 |
input_ids = chat_result.input_ids.to(device)
|
| 560 |
|
|
@@ -574,17 +573,21 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 574 |
)
|
| 575 |
|
| 576 |
# Generate using language model
|
|
|
|
|
|
|
| 577 |
output = self.language_model.generate(
|
|
|
|
| 578 |
inputs_embeds=inputs_embeds,
|
| 579 |
attention_mask=attention_mask,
|
| 580 |
generation_config=self.generation_config,
|
| 581 |
**generate_kwargs,
|
| 582 |
)
|
| 583 |
|
| 584 |
-
# When using inputs_embeds
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
|
|
|
| 588 |
|
| 589 |
def generate_streaming(
|
| 590 |
self,
|
|
@@ -632,7 +635,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 632 |
tokenize=True,
|
| 633 |
add_generation_prompt=True,
|
| 634 |
return_tensors="pt",
|
| 635 |
-
enable_thinking=
|
| 636 |
)
|
| 637 |
input_ids = chat_result.input_ids.to(device)
|
| 638 |
|
|
@@ -731,7 +734,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 731 |
tokenize=True,
|
| 732 |
add_generation_prompt=True,
|
| 733 |
return_tensors="pt",
|
| 734 |
-
enable_thinking=
|
| 735 |
).to(device)
|
| 736 |
|
| 737 |
if input_ids.dim() == 1:
|
|
|
|
| 120 |
super().__init__(config)
|
| 121 |
|
| 122 |
self.system_prompt = config.system_prompt
|
|
|
|
| 123 |
target_dtype = getattr(torch, config.model_dtype)
|
| 124 |
|
| 125 |
# Audio encoder (frozen)
|
|
|
|
| 136 |
self.generation_config.max_new_tokens = config.max_new_tokens
|
| 137 |
self.generation_config.min_new_tokens = config.min_new_tokens
|
| 138 |
self.generation_config.num_beams = config.num_beams
|
| 139 |
+
self.generation_config.do_sample = config.do_sample
|
| 140 |
+
# Set sampling params from config (None means use model defaults)
|
| 141 |
+
self.generation_config.temperature = config.temperature
|
| 142 |
+
self.generation_config.top_p = config.top_p
|
| 143 |
+
self.generation_config.top_k = config.top_k
|
| 144 |
self.generation_config.use_cache = config.use_cache
|
| 145 |
self.generation_config.length_penalty = config.length_penalty
|
| 146 |
self.generation_config.repetition_penalty = config.repetition_penalty
|
|
|
|
| 553 |
tokenize=True,
|
| 554 |
add_generation_prompt=True,
|
| 555 |
return_tensors="pt",
|
| 556 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 557 |
)
|
| 558 |
input_ids = chat_result.input_ids.to(device)
|
| 559 |
|
|
|
|
| 573 |
)
|
| 574 |
|
| 575 |
# Generate using language model
|
| 576 |
+
# Pass both input_ids and inputs_embeds so repetition_penalty works correctly
|
| 577 |
+
# (it needs input_ids to track which tokens have been used)
|
| 578 |
output = self.language_model.generate(
|
| 579 |
+
input_ids=input_ids,
|
| 580 |
inputs_embeds=inputs_embeds,
|
| 581 |
attention_mask=attention_mask,
|
| 582 |
generation_config=self.generation_config,
|
| 583 |
**generate_kwargs,
|
| 584 |
)
|
| 585 |
|
| 586 |
+
# When generate() is given both input_ids and inputs_embeds, it returns the full sequence (input tokens followed by generated tokens)
|
| 587 |
+
# Strip the input tokens to return only generated tokens
|
| 588 |
+
sequences = output if isinstance(output, torch.Tensor) else output.sequences
|
| 589 |
+
input_len = input_ids.shape[1]
|
| 590 |
+
return sequences[:, input_len:]
|
| 591 |
|
| 592 |
def generate_streaming(
|
| 593 |
self,
|
|
|
|
| 635 |
tokenize=True,
|
| 636 |
add_generation_prompt=True,
|
| 637 |
return_tensors="pt",
|
| 638 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 639 |
)
|
| 640 |
input_ids = chat_result.input_ids.to(device)
|
| 641 |
|
|
|
|
| 734 |
tokenize=True,
|
| 735 |
add_generation_prompt=True,
|
| 736 |
return_tensors="pt",
|
| 737 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 738 |
).to(device)
|
| 739 |
|
| 740 |
if input_ids.dim() == 1:
|