Training in progress - step 500
Browse files
- asr_modeling.py +12 -3
- asr_pipeline.py +3 -27
asr_modeling.py
CHANGED
|
@@ -212,8 +212,18 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 212 |
**cache_kwargs,
|
| 213 |
)
|
| 214 |
else:
|
| 215 |
-
# No saved adapters - initialize fresh LoRA for training
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
return model
|
| 219 |
finally:
|
|
@@ -382,7 +392,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 382 |
task_type="CAUSAL_LM",
|
| 383 |
)
|
| 384 |
self.language_model = get_peft_model(self.language_model, lora_config)
|
| 385 |
-
# LoRA params are trainable by default, base model stays frozen
|
| 386 |
|
| 387 |
def _init_tokenizer(self, config: ASRConfig):
|
| 388 |
"""Initialize tokenizer with audio token."""
|
|
|
|
| 212 |
**cache_kwargs,
|
| 213 |
)
|
| 214 |
else:
|
| 215 |
+
# No saved adapters - initialize fresh LLM LoRA for training
|
| 216 |
+
from peft import LoraConfig, get_peft_model
|
| 217 |
+
|
| 218 |
+
lora_config = LoraConfig(
|
| 219 |
+
r=config.lora_rank,
|
| 220 |
+
lora_alpha=config.lora_alpha,
|
| 221 |
+
target_modules=config.lora_target_modules,
|
| 222 |
+
lora_dropout=config.lora_dropout,
|
| 223 |
+
bias="none",
|
| 224 |
+
task_type="CAUSAL_LM",
|
| 225 |
+
)
|
| 226 |
+
model.language_model = get_peft_model(model.language_model, lora_config)
|
| 227 |
|
| 228 |
return model
|
| 229 |
finally:
|
|
|
|
| 392 |
task_type="CAUSAL_LM",
|
| 393 |
)
|
| 394 |
self.language_model = get_peft_model(self.language_model, lora_config)
|
|
|
|
| 395 |
|
| 396 |
def _init_tokenizer(self, config: ASRConfig):
|
| 397 |
"""Initialize tokenizer with audio token."""
|
asr_pipeline.py
CHANGED
|
@@ -485,40 +485,16 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 485 |
if not text:
|
| 486 |
return ""
|
| 487 |
|
| 488 |
-
original_len = len(text.split())
|
| 489 |
-
original_text = text # Keep for debug
|
| 490 |
-
|
| 491 |
# 1. LOWERCASE
|
| 492 |
text = text.lower()
|
| 493 |
|
| 494 |
-
# 2.
|
| 495 |
-
# If the model repeats the same phrase, keep only one instance.
|
| 496 |
-
words = text.split()
|
| 497 |
-
for n in range(1, min(15, len(words) // 2 + 1)):
|
| 498 |
-
last_sequence = words[-n:]
|
| 499 |
-
repeat_count = 0
|
| 500 |
-
idx = len(words) - n
|
| 501 |
-
while idx >= n and words[idx - n : idx] == last_sequence:
|
| 502 |
-
repeat_count += 1
|
| 503 |
-
idx -= n
|
| 504 |
-
|
| 505 |
-
if repeat_count >= 1:
|
| 506 |
-
words = words[: idx + n]
|
| 507 |
-
text = " ".join(words)
|
| 508 |
-
print(
|
| 509 |
-
f"[POSTPROCESS] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})"
|
| 510 |
-
)
|
| 511 |
-
print(f"[POSTPROCESS] Before: {original_text[:100]}...")
|
| 512 |
-
print(f"[POSTPROCESS] After: {text[:100]}...")
|
| 513 |
-
break
|
| 514 |
-
|
| 515 |
-
# 3. COMBINE ACRONYMS
|
| 516 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 517 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 518 |
|
| 519 |
-
#
|
| 520 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 521 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 522 |
|
| 523 |
-
#
|
| 524 |
return re.sub(r"\s+", " ", text).strip()
|
|
|
|
| 485 |
if not text:
|
| 486 |
return ""
|
| 487 |
|
|
|
|
|
|
|
|
|
|
| 488 |
# 1. LOWERCASE
|
| 489 |
text = text.lower()
|
| 490 |
|
| 491 |
+
# 2. COMBINE ACRONYMS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
# Merge consecutive single letters into one word (e.g., "u s a" -> "usa")
|
| 493 |
text = re.sub(r"\b([a-z])((?:\s+[a-z])+)\b", lambda m: m.group(0).replace(" ", ""), text)
|
| 494 |
|
| 495 |
+
# 3. NORMALIZE CURRENCY
|
| 496 |
# Convert "eur X" to "X euros" for Whisper normalizer compatibility
|
| 497 |
text = re.sub(r"\beur\s+(\d+)", r"\1 euros", text)
|
| 498 |
|
| 499 |
+
# 4. STRIP WHITESPACE
|
| 500 |
return re.sub(r"\s+", " ", text).strip()
|