Update custom model files, README, and requirements
- asr_modeling.py +4 -2
- asr_pipeline.py +3 -0
asr_modeling.py CHANGED

@@ -168,7 +168,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         decoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "trust_remote_code": True,
-            "tie_word_embeddings":
+            "tie_word_embeddings": True,
             "low_cpu_mem_usage": True,
             "dtype": dtype,
         }

@@ -342,7 +342,9 @@ class ASRModel(PreTrainedModel, GenerationMixin):

         # Create valid mask for variable-length samples and extract only real embeddings
         max_len = audio_embeds.shape[1]
-        valid_mask =
+        valid_mask = (
+            torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
+        )
         return audio_embeds[valid_mask]

     def forward(
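For background on the first hunk: tie_word_embeddings=True makes the decoder share its input embedding matrix with the output LM head instead of keeping two separate matrices. A minimal sketch of the effect, assuming a generic Hugging Face causal LM (the "gpt2" architecture id below is only for illustration and is not the decoder used here):

import re
from transformers import AutoConfig, AutoModelForCausalLM

# Build a small causal LM from config alone, with weight tying enabled.
config = AutoConfig.from_pretrained("gpt2", tie_word_embeddings=True)
model = AutoModelForCausalLM.from_config(config)

# With tied embeddings the LM head reuses the very same Parameter object
# as the token embedding table, so those weights are stored only once.
assert model.get_input_embeddings().weight is model.get_output_embeddings().weight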
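The second hunk rewrites valid_mask as a broadcast comparison: a (1, max_len) position index compared against a (batch, 1) vector of lengths yields a (batch, max_len) boolean mask that is True only at real frames, so audio_embeds[valid_mask] flattens away the padding. A standalone sketch with made-up shapes (the concrete sizes of audio_embeds and projector_lengths below are illustrative, not the repo's):

import torch

audio_embeds = torch.randn(3, 5, 8)          # (batch, max_len, hidden), padded
projector_lengths = torch.tensor([2, 5, 3])  # true frame count per sample

max_len = audio_embeds.shape[1]
valid_mask = (
    torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
)
# Boolean indexing keeps only the non-padding frames: 2 + 5 + 3 = 10 rows.
print(audio_embeds[valid_mask].shape)  # torch.Size([10, 8])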
asr_pipeline.py CHANGED

@@ -1,3 +1,4 @@
+import re
 from pathlib import Path
 from typing import Any

@@ -473,4 +474,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         tokens = tokens[0]

         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
+        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
+        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
         return {"text": text}
|