mazesmazes committed on
Commit
f2f75b3
·
verified ·
1 Parent(s): 3ea0555

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_modeling.py +1 -6
  2. asr_pipeline.py +9 -30
asr_modeling.py CHANGED
@@ -622,18 +622,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
622
  user_content += " " + self.TRANSCRIBE_PROMPT
623
  messages.append({"role": "user", "content": user_content})
624
 
625
- enable_thinking_val = getattr(self.config, "enable_thinking", False)
626
- print(f"[DEBUG generate] enable_thinking={enable_thinking_val}, system_prompt={system_prompt[:100] if system_prompt else None}...")
627
  chat_result = self.tokenizer.apply_chat_template(
628
  messages,
629
  tokenize=True,
630
  add_generation_prompt=True,
631
  return_tensors="pt",
632
- enable_thinking=enable_thinking_val,
633
  )
634
- # Debug: show the formatted prompt
635
- prompt_text = self.tokenizer.decode(chat_result.input_ids[0] if chat_result.input_ids.dim() > 1 else chat_result.input_ids)
636
- print(f"[DEBUG generate] Formatted prompt: {prompt_text[:500]}...")
637
  input_ids = chat_result.input_ids.to(device)
638
 
639
  if input_ids.dim() == 1:
 
622
  user_content += " " + self.TRANSCRIBE_PROMPT
623
  messages.append({"role": "user", "content": user_content})
624
 
 
 
625
  chat_result = self.tokenizer.apply_chat_template(
626
  messages,
627
  tokenize=True,
628
  add_generation_prompt=True,
629
  return_tensors="pt",
630
+ enable_thinking=getattr(self.config, "enable_thinking", False),
631
  )
 
 
 
632
  input_ids = chat_result.input_ids.to(device)
633
 
634
  if input_ids.dim() == 1:
asr_pipeline.py CHANGED
@@ -18,30 +18,13 @@ except ImportError:
18
  from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
 
20
  # Re-export for backwards compatibility
21
- __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking", "extract_thinking"]
22
 
23
  # Default TTS voice for Kokoro
24
  DEFAULT_TTS_VOICE = "af_heart"
25
  TTS_SAMPLE_RATE = 24000
26
 
27
 
28
- def extract_thinking(text: str) -> tuple[str, str]:
29
- """Extract thinking content from model output.
30
-
31
- Args:
32
- text: Model output text that may contain thinking tags
33
-
34
- Returns:
35
- Tuple of (thinking_content, response_text)
36
- """
37
- if not text:
38
- return "", ""
39
- match = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
40
- thinking = match.group(1).strip() if match else ""
41
- response = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
42
- return thinking, response
43
-
44
-
45
  def strip_thinking(text: str) -> str:
46
  """Remove <think>...</think> tags from model output.
47
 
@@ -51,8 +34,10 @@ def strip_thinking(text: str) -> str:
51
  Returns:
52
  Text with thinking content removed
53
  """
54
- _, response = extract_thinking(text)
55
- return response
 
 
56
 
57
 
58
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
@@ -518,17 +503,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
518
  tokens = [t for t in tokens.tolist() if t not in eos_set]
519
 
520
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
521
- # Debug: show raw text before extraction
522
- if "<think>" in text or "think" in text.lower()[:50]:
523
- print(f"[DEBUG postprocess] Raw text contains thinking: {text[:300]}...")
524
- # Extract thinking content before stripping
525
- thinking, response = extract_thinking(text)
526
  # Truncate repetitions at end of text
527
- response = _truncate_repetitions(response)
528
- result = {"text": response}
529
- if thinking:
530
- result["thinking"] = thinking
531
- return result
532
 
533
 
534
  def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
 
18
  from diarization import SpeakerDiarizer # type: ignore[no-redef]
19
 
20
  # Re-export for backwards compatibility
21
+ __all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"]
22
 
23
  # Default TTS voice for Kokoro
24
  DEFAULT_TTS_VOICE = "af_heart"
25
  TTS_SAMPLE_RATE = 24000
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def strip_thinking(text: str) -> str:
29
  """Remove <think>...</think> tags from model output.
30
 
 
34
  Returns:
35
  Text with thinking content removed
36
  """
37
+ if not text:
38
+ return text
39
+ text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
40
+ return text.strip()
41
 
42
 
43
  class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
 
503
  tokens = [t for t in tokens.tolist() if t not in eos_set]
504
 
505
  text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
506
+ # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
507
+ text = strip_thinking(text)
 
 
 
508
  # Truncate repetitions at end of text
509
+ text = _truncate_repetitions(text)
510
+ return {"text": text}
 
 
 
511
 
512
 
513
  def _truncate_repetitions(text: str, min_repeats: int = 3) -> str: