mazesmazes committed
Commit f89f77f · verified · 1 Parent(s): c293fd1

Update custom model files, README, and requirements
Files changed (4):
  1. asr_config.py +2 -4
  2. asr_modeling.py +5 -4
  3. asr_pipeline.py +28 -0
  4. asr_processing.py +0 -2
asr_config.py CHANGED

@@ -14,11 +14,10 @@ class ASRConfig(transformers.PretrainedConfig):
         attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
         num_beams: Optional[int] = None,
-        system_prompt: str = "/no_think /system_override",
+        system_prompt: str = "You are a helpful transcription assistant",
         user_prompt: str = "Please transcribe this English audio into text: <audio>",
         encoder_dim: Optional[int] = None,
         llm_dim: Optional[int] = None,
-        encoder_stride: int = 2,  # Temporal downsampling factor of audio encoder (legacy, use encoder_conv_layers)
         # Encoder conv layers: list of (padding, kernel_size, stride) tuples
         # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
         encoder_conv_layers: Optional[list] = None,
@@ -52,7 +51,7 @@ class ASRConfig(transformers.PretrainedConfig):
         # Set default generation parameters (greedy decoding only)
         generation_defaults = {
             "num_beams": 1,
-            "max_new_tokens": 96,
+            "max_new_tokens": 256,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,
@@ -70,7 +69,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
-        self.encoder_stride = encoder_stride
         # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
         self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
         self.audio_sample_rate = audio_sample_rate
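Both removals here retire the legacy scalar encoder_stride in favor of the per-layer (padding, kernel_size, stride) tuples in encoder_conv_layers. A quick sanity check that the default [(1, 3, 1), (1, 3, 2)] reproduces the old factor-of-2 temporal downsampling (the helper below is illustrative, not part of the repo):

def overall_stride(conv_layers):
    # The net temporal downsampling is the product of the per-layer strides;
    # padding and kernel size only affect edge behavior, not the rate.
    factor = 1
    for _padding, _kernel_size, stride in conv_layers:
        factor *= stride
    return factor

print(overall_stride([(1, 3, 1), (1, 3, 2)]))  # 2, matching the removed encoder_stride=2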
asr_modeling.py CHANGED

@@ -96,7 +96,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         super().__init__(config)

         self.system_prompt = config.system_prompt
-        self.encoder_stride = config.encoder_stride
         target_dtype = getattr(torch, config.model_dtype)

         # Audio encoder (frozen)
@@ -121,7 +120,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.generation_config.length_penalty = config.length_penalty
         self.generation_config.repetition_penalty = config.repetition_penalty
         self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
-        self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.generation_config.eos_token_id = [
+            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
+            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+        ]
         self.generation_config.pad_token_id = self.tokenizer.pad_token_id

         # Feature extractor for audio preprocessing
@@ -145,7 +147,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         encoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "low_cpu_mem_usage": True,
-            "torch_dtype": dtype,
+            "dtype": dtype,
         }

         if "whisper" in config.audio_model_id.lower():
@@ -296,7 +298,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             feature_extractor=self.feature_extractor,
             tokenizer=self.tokenizer,
             projector=self.projector,
-            encoder_stride=self.encoder_stride,
             encoder_conv_layers=self.config.encoder_conv_layers,
         )

asr_pipeline.py CHANGED

@@ -476,4 +476,32 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
         text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+        # Truncate if a word repeats more than 3 times consecutively
+        text = self._truncate_repetitions(text, max_repeats=3)
         return {"text": text}
+
+    def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
+        """Truncate text when a word repeats more than max_repeats times consecutively.
+
+        Args:
+            text: Input text to check for repetitions
+            max_repeats: Maximum allowed consecutive repetitions (default 3)
+
+        Returns:
+            Truncated text if repetition detected, otherwise original text
+        """
+        words = text.split()
+        if len(words) <= max_repeats:
+            return text
+
+        repeat_count = 1
+        for i in range(1, len(words)):
+            if words[i].lower() == words[i - 1].lower():
+                repeat_count += 1
+                if repeat_count > max_repeats:
+                    # Keep up to max_repeats of the repeated word
+                    return " ".join(words[:i])
+            else:
+                repeat_count = 1
+
+        return text
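The repetition guard pairs with the max_new_tokens bump in asr_config.py (96 → 256): a larger budget gives degenerate looping decodes more room, so the pipeline now cuts the transcript once a word exceeds three consecutive repeats. Note that it drops the entire remainder of the string, not just the surplus repeats. A standalone copy of the same logic for a quick check (the test strings are made up):

def truncate_repetitions(text, max_repeats=3):
    words = text.split()
    if len(words) <= max_repeats:
        return text
    repeat_count = 1
    for i in range(1, len(words)):
        if words[i].lower() == words[i - 1].lower():
            repeat_count += 1
            if repeat_count > max_repeats:
                # Truncate at the first word past max_repeats repeats.
                return " ".join(words[:i])
        else:
            repeat_count = 1
    return text

print(truncate_repetitions("the cat sat sat sat sat on the mat"))
# -> "the cat sat sat sat"  (truncates at the 4th repeat; the tail is dropped too)
print(truncate_repetitions("no repeats here"))
# -> "no repeats here"  (unchanged)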
asr_processing.py CHANGED

@@ -26,14 +26,12 @@ class ASRProcessor(ProcessorMixin):
         feature_extractor,
         tokenizer,
         projector=None,
-        encoder_stride: int = 2,
         encoder_conv_layers: Optional[list] = None,
     ):
         self.feature_extractor = feature_extractor
         self.tokenizer = tokenizer
         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
         self.projector = projector
-        self.encoder_stride = encoder_stride  # Legacy, kept for compatibility
         self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

     def _compute_encoder_output_length(self, mel_length: int) -> int:
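_compute_encoder_output_length itself is not shown in this diff, but the (padding, kernel_size, stride) tuples suggest the standard Conv1d length recurrence applied per layer. A sketch under that assumption; the repo's actual implementation may differ:

def compute_encoder_output_length(mel_length, conv_layers=((1, 3, 1), (1, 3, 2))):
    # Standard Conv1d output length, folded over each configured layer:
    #   L_out = floor((L_in + 2*padding - kernel_size) / stride) + 1
    length = mel_length
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - kernel_size) // stride + 1
    return length

print(compute_encoder_output_length(3000))
# -> 1500: a 30 s Whisper mel spectrogram (3000 frames) maps to 1500 encoder
#    states, consistent with the 2x downsampling of the default conv stack.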