mazesmazes committed on
Commit
c19c0e3
·
verified ·
1 Parent(s): a1b0eab

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_modeling.py +4 -2
  2. asr_pipeline.py +3 -0
asr_modeling.py CHANGED
@@ -168,7 +168,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
168
  decoder_kwargs = {
169
  "attn_implementation": config.attn_implementation,
170
  "trust_remote_code": True,
171
- "tie_word_embeddings": False,
172
  "low_cpu_mem_usage": True,
173
  "dtype": dtype,
174
  }
@@ -342,7 +342,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):
342
 
343
  # Create valid mask for variable-length samples and extract only real embeddings
344
  max_len = audio_embeds.shape[1]
345
- valid_mask = torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
 
 
346
  return audio_embeds[valid_mask]
347
 
348
  def forward(
 
168
  decoder_kwargs = {
169
  "attn_implementation": config.attn_implementation,
170
  "trust_remote_code": True,
171
+ "tie_word_embeddings": True,
172
  "low_cpu_mem_usage": True,
173
  "dtype": dtype,
174
  }
 
342
 
343
  # Create valid mask for variable-length samples and extract only real embeddings
344
  max_len = audio_embeds.shape[1]
345
+ valid_mask = (
346
+ torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
347
+ )
348
  return audio_embeds[valid_mask]
349
 
350
  def forward(
asr_pipeline.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pathlib import Path
2
  from typing import Any
3
 
@@ -473,4 +474,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
473
  tokens = tokens[0]
474
 
475
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
 
 
476
  return {"text": text}
 
1
+ import re
2
  from pathlib import Path
3
  from typing import Any
4
 
 
474
  tokens = tokens[0]
475
 
476
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
477
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
478
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
479
  return {"text": text}