mazesmazes committed on
Commit
4fbcc1a
·
verified ·
1 Parent(s): cb0de5c

Training in progress - step 30500

Browse files
Files changed (5) hide show
  1. asr_config.py +3 -1
  2. asr_modeling.py +4 -5
  3. asr_pipeline.py +0 -28
  4. asr_processing.py +2 -0
  5. model.safetensors +1 -1
asr_config.py CHANGED
@@ -18,6 +18,7 @@ class ASRConfig(transformers.PretrainedConfig):
18
  user_prompt: str = "Please transcribe this English audio into text: <audio>",
19
  encoder_dim: Optional[int] = None,
20
  llm_dim: Optional[int] = None,
 
21
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
22
  # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
23
  encoder_conv_layers: Optional[list] = None,
@@ -51,7 +52,7 @@ class ASRConfig(transformers.PretrainedConfig):
51
  # Set default generation parameters (greedy decoding only)
52
  generation_defaults = {
53
  "num_beams": 1,
54
- "max_new_tokens": 256,
55
  "repetition_penalty": 1.0,
56
  "length_penalty": 1.0,
57
  "no_repeat_ngram_size": 0,
@@ -69,6 +70,7 @@ class ASRConfig(transformers.PretrainedConfig):
69
  self.user_prompt = user_prompt
70
  self.encoder_dim = encoder_dim
71
  self.llm_dim = llm_dim
 
72
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
73
  self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
74
  self.audio_sample_rate = audio_sample_rate
 
18
  user_prompt: str = "Please transcribe this English audio into text: <audio>",
19
  encoder_dim: Optional[int] = None,
20
  llm_dim: Optional[int] = None,
21
+ encoder_stride: int = 2, # Temporal downsampling factor of audio encoder (legacy, use encoder_conv_layers)
22
  # Encoder conv layers: list of (padding, kernel_size, stride) tuples
23
  # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
24
  encoder_conv_layers: Optional[list] = None,
 
52
  # Set default generation parameters (greedy decoding only)
53
  generation_defaults = {
54
  "num_beams": 1,
55
+ "max_new_tokens": 96,
56
  "repetition_penalty": 1.0,
57
  "length_penalty": 1.0,
58
  "no_repeat_ngram_size": 0,
 
70
  self.user_prompt = user_prompt
71
  self.encoder_dim = encoder_dim
72
  self.llm_dim = llm_dim
73
+ self.encoder_stride = encoder_stride
74
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
75
  self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
76
  self.audio_sample_rate = audio_sample_rate
asr_modeling.py CHANGED
@@ -96,6 +96,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
96
  super().__init__(config)
97
 
98
  self.system_prompt = config.system_prompt
 
99
  target_dtype = getattr(torch, config.model_dtype)
100
 
101
  # Audio encoder (frozen)
@@ -120,10 +121,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
120
  self.generation_config.length_penalty = config.length_penalty
121
  self.generation_config.repetition_penalty = config.repetition_penalty
122
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
123
- self.generation_config.eos_token_id = [
124
- self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
125
- self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
126
- ]
127
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
128
 
129
  # Feature extractor for audio preprocessing
@@ -147,7 +145,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
147
  encoder_kwargs = {
148
  "attn_implementation": config.attn_implementation,
149
  "low_cpu_mem_usage": True,
150
- "dtype": dtype,
151
  }
152
 
153
  if "whisper" in config.audio_model_id.lower():
@@ -298,6 +296,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
298
  feature_extractor=self.feature_extractor,
299
  tokenizer=self.tokenizer,
300
  projector=self.projector,
 
301
  encoder_conv_layers=self.config.encoder_conv_layers,
302
  )
303
 
 
96
  super().__init__(config)
97
 
98
  self.system_prompt = config.system_prompt
99
+ self.encoder_stride = config.encoder_stride
100
  target_dtype = getattr(torch, config.model_dtype)
101
 
102
  # Audio encoder (frozen)
 
121
  self.generation_config.length_penalty = config.length_penalty
122
  self.generation_config.repetition_penalty = config.repetition_penalty
123
  self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
124
+ self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
 
 
 
125
  self.generation_config.pad_token_id = self.tokenizer.pad_token_id
126
 
127
  # Feature extractor for audio preprocessing
 
145
  encoder_kwargs = {
146
  "attn_implementation": config.attn_implementation,
147
  "low_cpu_mem_usage": True,
148
+ "torch_dtype": dtype,
149
  }
150
 
151
  if "whisper" in config.audio_model_id.lower():
 
296
  feature_extractor=self.feature_extractor,
297
  tokenizer=self.tokenizer,
298
  projector=self.projector,
299
+ encoder_stride=self.encoder_stride,
300
  encoder_conv_layers=self.config.encoder_conv_layers,
301
  )
302
 
asr_pipeline.py CHANGED
@@ -476,32 +476,4 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
- # Truncate if a word repeats more than 3 times consecutively
480
- text = self._truncate_repetitions(text, max_repeats=3)
481
  return {"text": text}
482
-
483
- def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
484
- """Truncate text when a word repeats more than max_repeats times consecutively.
485
-
486
- Args:
487
- text: Input text to check for repetitions
488
- max_repeats: Maximum allowed consecutive repetitions (default 3)
489
-
490
- Returns:
491
- Truncated text if repetition detected, otherwise original text
492
- """
493
- words = text.split()
494
- if len(words) <= max_repeats:
495
- return text
496
-
497
- repeat_count = 1
498
- for i in range(1, len(words)):
499
- if words[i].lower() == words[i - 1].lower():
500
- repeat_count += 1
501
- if repeat_count > max_repeats:
502
- # Keep up to max_repeats of the repeated word
503
- return " ".join(words[:i])
504
- else:
505
- repeat_count = 1
506
-
507
- return text
 
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
  # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
  text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
 
 
479
  return {"text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
asr_processing.py CHANGED
@@ -26,12 +26,14 @@ class ASRProcessor(ProcessorMixin):
26
  feature_extractor,
27
  tokenizer,
28
  projector=None,
 
29
  encoder_conv_layers: Optional[list] = None,
30
  ):
31
  self.feature_extractor = feature_extractor
32
  self.tokenizer = tokenizer
33
  self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
34
  self.projector = projector
 
35
  self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
36
 
37
  def _compute_encoder_output_length(self, mel_length: int) -> int:
 
26
  feature_extractor,
27
  tokenizer,
28
  projector=None,
29
+ encoder_stride: int = 2,
30
  encoder_conv_layers: Optional[list] = None,
31
  ):
32
  self.feature_extractor = feature_extractor
33
  self.tokenizer = tokenizer
34
  self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
35
  self.projector = projector
36
+ self.encoder_stride = encoder_stride # Legacy, kept for compatibility
37
  self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS
38
 
39
  def _compute_encoder_output_length(self, mel_length: int) -> int:
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ba6c7ffd625764146a13f4678f459fe084bec15e140db28016239aac516f158
3
  size 58732960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5ce422bc8492d610e980da968fec9e97cc628c00a928b8d9cdc24197ab5910c
3
  size 58732960