Training in progress - step 1000
Browse files
- asr_modeling.py +4 -21
- asr_pipeline.py +1 -3
- asr_processing.py +1 -1
asr_modeling.py
CHANGED
|
@@ -51,9 +51,7 @@ def _compute_mask_indices(
|
|
| 51 |
raise ValueError(f"mask_length must be >= 1, got {mask_length}")
|
| 52 |
|
| 53 |
if mask_length > sequence_length:
|
| 54 |
-
raise ValueError(
|
| 55 |
-
f"mask_length {mask_length} must be <= sequence_length {sequence_length}"
|
| 56 |
-
)
|
| 57 |
|
| 58 |
# Compute number of masked spans per sample
|
| 59 |
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
|
|
@@ -190,21 +188,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 190 |
state_dict = load_file(model_file)
|
| 191 |
model.load_state_dict(state_dict, strict=False)
|
| 192 |
|
| 193 |
-
# Load LoRA adapter if present
|
| 194 |
-
adapter_config = cached_file(
|
| 195 |
-
pretrained_model_name_or_path,
|
| 196 |
-
"adapter_config.json",
|
| 197 |
-
_raise_exceptions_for_missing_entries=False,
|
| 198 |
-
**cache_kwargs,
|
| 199 |
-
)
|
| 200 |
-
if adapter_config is not None:
|
| 201 |
-
from peft import PeftModel
|
| 202 |
-
|
| 203 |
-
# Pass original repo ID to PEFT, let it handle caching
|
| 204 |
-
model.language_model = PeftModel.from_pretrained(
|
| 205 |
-
model.language_model, pretrained_model_name_or_path, is_trainable=False
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
return model
|
| 209 |
finally:
|
| 210 |
cls._is_loading_from_pretrained = False
|
|
@@ -728,14 +711,14 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 728 |
thread.start()
|
| 729 |
|
| 730 |
# Yield tokens as they're generated, filtering out <think>...</think> blocks
|
| 731 |
-
#
|
| 732 |
-
in_think_block =
|
| 733 |
buffer = ""
|
| 734 |
|
| 735 |
for text in streamer:
|
| 736 |
buffer += text
|
| 737 |
|
| 738 |
-
# Check for think block start (in case model outputs
|
| 739 |
while "<think>" in buffer:
|
| 740 |
in_think_block = True
|
| 741 |
# Yield any text before <think>
|
|
|
|
| 51 |
raise ValueError(f"mask_length must be >= 1, got {mask_length}")
|
| 52 |
|
| 53 |
if mask_length > sequence_length:
|
| 54 |
+
raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}")
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# Compute number of masked spans per sample
|
| 57 |
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
|
|
|
|
| 188 |
state_dict = load_file(model_file)
|
| 189 |
model.load_state_dict(state_dict, strict=False)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
return model
|
| 192 |
finally:
|
| 193 |
cls._is_loading_from_pretrained = False
|
|
|
|
| 711 |
thread.start()
|
| 712 |
|
| 713 |
# Yield tokens as they're generated, filtering out <think>...</think> blocks
|
| 714 |
+
# Start assuming no think block - only filter when we see <think>
|
| 715 |
+
in_think_block = False
|
| 716 |
buffer = ""
|
| 717 |
|
| 718 |
for text in streamer:
|
| 719 |
buffer += text
|
| 720 |
|
| 721 |
+
# Check for think block start (in case model outputs think blocks)
|
| 722 |
while "<think>" in buffer:
|
| 723 |
in_think_block = True
|
| 724 |
# Yield any text before <think>
|
asr_pipeline.py
CHANGED
|
@@ -507,6 +507,4 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 507 |
break
|
| 508 |
|
| 509 |
# 3. STRIP WHITESPACE
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
return text
|
|
|
|
| 507 |
break
|
| 508 |
|
| 509 |
# 3. STRIP WHITESPACE
|
| 510 |
+
return re.sub(r"\s+", " ", text).strip()
|
|
|
|
|
|
asr_processing.py
CHANGED
|
@@ -106,7 +106,7 @@ class ASRProcessor(ProcessorMixin):
|
|
| 106 |
input_ids = tokenized
|
| 107 |
else:
|
| 108 |
# BatchEncoding or dict-like object
|
| 109 |
-
input_ids = tokenized
|
| 110 |
|
| 111 |
if input_ids.dim() == 1:
|
| 112 |
input_ids = input_ids.unsqueeze(0)
|
|
|
|
| 106 |
input_ids = tokenized
|
| 107 |
else:
|
| 108 |
# BatchEncoding or dict-like object
|
| 109 |
+
input_ids = tokenized.get("input_ids", tokenized.input_ids)
|
| 110 |
|
| 111 |
if input_ids.dim() == 1:
|
| 112 |
input_ids = input_ids.unsqueeze(0)
|