Update custom model files, README, and requirements
Browse files
- asr_modeling.py +4 -3
- asr_pipeline.py +10 -2
- diarization.py +1 -1
asr_modeling.py
CHANGED
|
@@ -120,6 +120,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 120 |
super().__init__(config)
|
| 121 |
|
| 122 |
self.system_prompt = config.system_prompt
|
|
|
|
| 123 |
target_dtype = getattr(torch, config.model_dtype)
|
| 124 |
|
| 125 |
# Audio encoder (frozen)
|
|
@@ -553,7 +554,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 553 |
tokenize=True,
|
| 554 |
add_generation_prompt=True,
|
| 555 |
return_tensors="pt",
|
| 556 |
-
enable_thinking=
|
| 557 |
)
|
| 558 |
input_ids = chat_result.input_ids.to(device)
|
| 559 |
|
|
@@ -631,7 +632,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 631 |
tokenize=True,
|
| 632 |
add_generation_prompt=True,
|
| 633 |
return_tensors="pt",
|
| 634 |
-
enable_thinking=
|
| 635 |
)
|
| 636 |
input_ids = chat_result.input_ids.to(device)
|
| 637 |
|
|
@@ -730,7 +731,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 730 |
tokenize=True,
|
| 731 |
add_generation_prompt=True,
|
| 732 |
return_tensors="pt",
|
| 733 |
-
enable_thinking=
|
| 734 |
).to(device)
|
| 735 |
|
| 736 |
if input_ids.dim() == 1:
|
|
|
|
| 120 |
super().__init__(config)
|
| 121 |
|
| 122 |
self.system_prompt = config.system_prompt
|
| 123 |
+
self.enable_thinking = False # Can be enabled for experimental thinking mode
|
| 124 |
target_dtype = getattr(torch, config.model_dtype)
|
| 125 |
|
| 126 |
# Audio encoder (frozen)
|
|
|
|
| 554 |
tokenize=True,
|
| 555 |
add_generation_prompt=True,
|
| 556 |
return_tensors="pt",
|
| 557 |
+
enable_thinking=self.enable_thinking,
|
| 558 |
)
|
| 559 |
input_ids = chat_result.input_ids.to(device)
|
| 560 |
|
|
|
|
| 632 |
tokenize=True,
|
| 633 |
add_generation_prompt=True,
|
| 634 |
return_tensors="pt",
|
| 635 |
+
enable_thinking=self.enable_thinking,
|
| 636 |
)
|
| 637 |
input_ids = chat_result.input_ids.to(device)
|
| 638 |
|
|
|
|
| 731 |
tokenize=True,
|
| 732 |
add_generation_prompt=True,
|
| 733 |
return_tensors="pt",
|
| 734 |
+
enable_thinking=self.enable_thinking,
|
| 735 |
).to(device)
|
| 736 |
|
| 737 |
if input_ids.dim() == 1:
|
asr_pipeline.py
CHANGED
|
@@ -446,7 +446,9 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|
|
| 446 |
text = char_pattern.sub(r"\1", text)
|
| 447 |
|
| 448 |
# 2. Truncate repeated words at end (e.g., "the the the" -> "the")
|
| 449 |
-
word_pattern = re.compile(
|
|
|
|
|
|
|
| 450 |
while word_pattern.search(text):
|
| 451 |
text = word_pattern.sub(r"\1", text)
|
| 452 |
|
|
@@ -461,7 +463,13 @@ def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
|
|
| 461 |
# Build pattern to match repeated phrases at end
|
| 462 |
phrase_escaped = re.escape(phrase)
|
| 463 |
phrase_pattern = re.compile(
|
| 464 |
-
r"(^|.*?\s)("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
re.IGNORECASE,
|
| 466 |
)
|
| 467 |
match = phrase_pattern.match(text)
|
|
|
|
| 446 |
text = char_pattern.sub(r"\1", text)
|
| 447 |
|
| 448 |
# 2. Truncate repeated words at end (e.g., "the the the" -> "the")
|
| 449 |
+
word_pattern = re.compile(
|
| 450 |
+
r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
|
| 451 |
+
)
|
| 452 |
while word_pattern.search(text):
|
| 453 |
text = word_pattern.sub(r"\1", text)
|
| 454 |
|
|
|
|
| 463 |
# Build pattern to match repeated phrases at end
|
| 464 |
phrase_escaped = re.escape(phrase)
|
| 465 |
phrase_pattern = re.compile(
|
| 466 |
+
r"(^|.*?\s)("
|
| 467 |
+
+ phrase_escaped
|
| 468 |
+
+ r")(?:\s+"
|
| 469 |
+
+ phrase_escaped
|
| 470 |
+
+ r"){"
|
| 471 |
+
+ str(min_repeats - 1)
|
| 472 |
+
+ r",}\s*$",
|
| 473 |
re.IGNORECASE,
|
| 474 |
)
|
| 475 |
match = phrase_pattern.match(text)
|
diarization.py
CHANGED
|
@@ -737,7 +737,7 @@ class SpeakerDiarizer:
|
|
| 737 |
|
| 738 |
cls._pyannote_pipeline = Pipeline.from_pretrained(
|
| 739 |
"pyannote/speaker-diarization-3.1",
|
| 740 |
-
|
| 741 |
)
|
| 742 |
cls._pyannote_pipeline.to(torch.device(_get_device()))
|
| 743 |
|
|
|
|
| 737 |
|
| 738 |
cls._pyannote_pipeline = Pipeline.from_pretrained(
|
| 739 |
"pyannote/speaker-diarization-3.1",
|
| 740 |
+
token=hf_token,
|
| 741 |
)
|
| 742 |
cls._pyannote_pipeline.to(torch.device(_get_device()))
|
| 743 |
|