Update custom model files, README, and requirements
Browse files
- asr_modeling.py +1 -6
- asr_pipeline.py +9 -30
asr_modeling.py
CHANGED
|
@@ -622,18 +622,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 622 |
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 623 |
messages.append({"role": "user", "content": user_content})
|
| 624 |
|
| 625 |
-
enable_thinking_val = getattr(self.config, "enable_thinking", False)
|
| 626 |
-
print(f"[DEBUG generate] enable_thinking={enable_thinking_val}, system_prompt={system_prompt[:100] if system_prompt else None}...")
|
| 627 |
chat_result = self.tokenizer.apply_chat_template(
|
| 628 |
messages,
|
| 629 |
tokenize=True,
|
| 630 |
add_generation_prompt=True,
|
| 631 |
return_tensors="pt",
|
| 632 |
-
enable_thinking=
|
| 633 |
)
|
| 634 |
-
# Debug: show the formatted prompt
|
| 635 |
-
prompt_text = self.tokenizer.decode(chat_result.input_ids[0] if chat_result.input_ids.dim() > 1 else chat_result.input_ids)
|
| 636 |
-
print(f"[DEBUG generate] Formatted prompt: {prompt_text[:500]}...")
|
| 637 |
input_ids = chat_result.input_ids.to(device)
|
| 638 |
|
| 639 |
if input_ids.dim() == 1:
|
|
|
|
| 622 |
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 623 |
messages.append({"role": "user", "content": user_content})
|
| 624 |
|
|
|
|
|
|
|
| 625 |
chat_result = self.tokenizer.apply_chat_template(
|
| 626 |
messages,
|
| 627 |
tokenize=True,
|
| 628 |
add_generation_prompt=True,
|
| 629 |
return_tensors="pt",
|
| 630 |
+
enable_thinking=getattr(self.config, "enable_thinking", False),
|
| 631 |
)
|
|
|
|
|
|
|
|
|
|
| 632 |
input_ids = chat_result.input_ids.to(device)
|
| 633 |
|
| 634 |
if input_ids.dim() == 1:
|
asr_pipeline.py
CHANGED
|
@@ -18,30 +18,13 @@ except ImportError:
|
|
| 18 |
from diarization import SpeakerDiarizer # type: ignore[no-redef]
|
| 19 |
|
| 20 |
# Re-export for backwards compatibility
|
| 21 |
-
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"
|
| 22 |
|
| 23 |
# Default TTS voice for Kokoro
|
| 24 |
DEFAULT_TTS_VOICE = "af_heart"
|
| 25 |
TTS_SAMPLE_RATE = 24000
|
| 26 |
|
| 27 |
|
| 28 |
-
def extract_thinking(text: str) -> tuple[str, str]:
|
| 29 |
-
"""Extract thinking content from model output.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
text: Model output text that may contain thinking tags
|
| 33 |
-
|
| 34 |
-
Returns:
|
| 35 |
-
Tuple of (thinking_content, response_text)
|
| 36 |
-
"""
|
| 37 |
-
if not text:
|
| 38 |
-
return "", ""
|
| 39 |
-
match = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
|
| 40 |
-
thinking = match.group(1).strip() if match else ""
|
| 41 |
-
response = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
|
| 42 |
-
return thinking, response
|
| 43 |
-
|
| 44 |
-
|
| 45 |
def strip_thinking(text: str) -> str:
|
| 46 |
"""Remove <think>...</think> tags from model output.
|
| 47 |
|
|
@@ -51,8 +34,10 @@ def strip_thinking(text: str) -> str:
|
|
| 51 |
Returns:
|
| 52 |
Text with thinking content removed
|
| 53 |
"""
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
@@ -518,17 +503,11 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 518 |
tokens = [t for t in tokens.tolist() if t not in eos_set]
|
| 519 |
|
| 520 |
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
|
| 521 |
-
#
|
| 522 |
-
|
| 523 |
-
print(f"[DEBUG postprocess] Raw text contains thinking: {text[:300]}...")
|
| 524 |
-
# Extract thinking content before stripping
|
| 525 |
-
thinking, response = extract_thinking(text)
|
| 526 |
# Truncate repetitions at end of text
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
if thinking:
|
| 530 |
-
result["thinking"] = thinking
|
| 531 |
-
return result
|
| 532 |
|
| 533 |
|
| 534 |
def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|
|
|
|
| 18 |
from diarization import SpeakerDiarizer # type: ignore[no-redef]
|
| 19 |
|
| 20 |
# Re-export for backwards compatibility
|
| 21 |
+
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline", "strip_thinking"]
|
| 22 |
|
| 23 |
# Default TTS voice for Kokoro
|
| 24 |
DEFAULT_TTS_VOICE = "af_heart"
|
| 25 |
TTS_SAMPLE_RATE = 24000
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def strip_thinking(text: str) -> str:
|
| 29 |
"""Remove <think>...</think> tags from model output.
|
| 30 |
|
|
|
|
| 34 |
Returns:
|
| 35 |
Text with thinking content removed
|
| 36 |
"""
|
| 37 |
+
if not text:
|
| 38 |
+
return text
|
| 39 |
+
text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
|
| 40 |
+
return text.strip()
|
| 41 |
|
| 42 |
|
| 43 |
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
|
|
| 503 |
tokens = [t for t in tokens.tolist() if t not in eos_set]
|
| 504 |
|
| 505 |
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
|
| 506 |
+
# Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
|
| 507 |
+
text = strip_thinking(text)
|
|
|
|
|
|
|
|
|
|
| 508 |
# Truncate repetitions at end of text
|
| 509 |
+
text = _truncate_repetitions(text)
|
| 510 |
+
return {"text": text}
|
|
|
|
|
|
|
|
|
|
| 511 |
|
| 512 |
|
| 513 |
def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|