mazesmazes committed on
Commit
866f0c8
·
verified ·
1 Parent(s): cf9c6ea

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. asr_modeling.py +4 -3
  2. asr_pipeline.py +10 -2
  3. diarization.py +1 -1
asr_modeling.py CHANGED
@@ -120,6 +120,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
120
  super().__init__(config)
121
 
122
  self.system_prompt = config.system_prompt
 
123
  target_dtype = getattr(torch, config.model_dtype)
124
 
125
  # Audio encoder (frozen)
@@ -553,7 +554,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
553
  tokenize=True,
554
  add_generation_prompt=True,
555
  return_tensors="pt",
556
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
557
  )
558
  input_ids = chat_result.input_ids.to(device)
559
 
@@ -631,7 +632,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
631
  tokenize=True,
632
  add_generation_prompt=True,
633
  return_tensors="pt",
634
- enable_thinking=False, # Disable Qwen3 thinking mode for ASR
635
  )
636
  input_ids = chat_result.input_ids.to(device)
637
 
@@ -730,7 +731,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
730
  tokenize=True,
731
  add_generation_prompt=True,
732
  return_tensors="pt",
733
- enable_thinking=False,
734
  ).to(device)
735
 
736
  if input_ids.dim() == 1:
 
120
  super().__init__(config)
121
 
122
  self.system_prompt = config.system_prompt
123
+ self.enable_thinking = False # Can be enabled for experimental thinking mode
124
  target_dtype = getattr(torch, config.model_dtype)
125
 
126
  # Audio encoder (frozen)
 
554
  tokenize=True,
555
  add_generation_prompt=True,
556
  return_tensors="pt",
557
+ enable_thinking=self.enable_thinking,
558
  )
559
  input_ids = chat_result.input_ids.to(device)
560
 
 
632
  tokenize=True,
633
  add_generation_prompt=True,
634
  return_tensors="pt",
635
+ enable_thinking=self.enable_thinking,
636
  )
637
  input_ids = chat_result.input_ids.to(device)
638
 
 
731
  tokenize=True,
732
  add_generation_prompt=True,
733
  return_tensors="pt",
734
+ enable_thinking=self.enable_thinking,
735
  ).to(device)
736
 
737
  if input_ids.dim() == 1:
asr_pipeline.py CHANGED
@@ -446,7 +446,9 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
446
  text = char_pattern.sub(r"\1", text)
447
 
448
  # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
449
- word_pattern = re.compile(r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE)
 
 
450
  while word_pattern.search(text):
451
  text = word_pattern.sub(r"\1", text)
452
 
@@ -461,7 +463,13 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
461
  # Build pattern to match repeated phrases at end
462
  phrase_escaped = re.escape(phrase)
463
  phrase_pattern = re.compile(
464
- r"(^|.*?\s)(" + phrase_escaped + r")(?:\s+" + phrase_escaped + r"){" + str(min_repeats - 1) + r",}\s*$",
 
 
 
 
 
 
465
  re.IGNORECASE,
466
  )
467
  match = phrase_pattern.match(text)
 
446
  text = char_pattern.sub(r"\1", text)
447
 
448
  # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
449
+ word_pattern = re.compile(
450
+ r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
451
+ )
452
  while word_pattern.search(text):
453
  text = word_pattern.sub(r"\1", text)
454
 
 
463
  # Build pattern to match repeated phrases at end
464
  phrase_escaped = re.escape(phrase)
465
  phrase_pattern = re.compile(
466
+ r"(^|.*?\s)("
467
+ + phrase_escaped
468
+ + r")(?:\s+"
469
+ + phrase_escaped
470
+ + r"){"
471
+ + str(min_repeats - 1)
472
+ + r",}\s*$",
473
  re.IGNORECASE,
474
  )
475
  match = phrase_pattern.match(text)
diarization.py CHANGED
@@ -737,7 +737,7 @@ class SpeakerDiarizer:
737
 
738
  cls._pyannote_pipeline = Pipeline.from_pretrained(
739
  "pyannote/speaker-diarization-3.1",
740
- use_auth_token=hf_token,
741
  )
742
  cls._pyannote_pipeline.to(torch.device(_get_device()))
743
 
 
737
 
738
  cls._pyannote_pipeline = Pipeline.from_pretrained(
739
  "pyannote/speaker-diarization-3.1",
740
+ token=hf_token,
741
  )
742
  cls._pyannote_pipeline.to(torch.device(_get_device()))
743