mazesmazes committed on
Commit
898105b
·
verified ·
1 Parent(s): d3de5fa

Training in progress - step 1000

Browse files
Files changed (3) hide show
  1. asr_modeling.py +4 -21
  2. asr_pipeline.py +1 -3
  3. asr_processing.py +1 -1
asr_modeling.py CHANGED
@@ -51,9 +51,7 @@ def _compute_mask_indices(
51
  raise ValueError(f"mask_length must be >= 1, got {mask_length}")
52
 
53
  if mask_length > sequence_length:
54
- raise ValueError(
55
- f"mask_length {mask_length} must be <= sequence_length {sequence_length}"
56
- )
57
 
58
  # Compute number of masked spans per sample
59
  num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
@@ -190,21 +188,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
190
  state_dict = load_file(model_file)
191
  model.load_state_dict(state_dict, strict=False)
192
 
193
- # Load LoRA adapter if present
194
- adapter_config = cached_file(
195
- pretrained_model_name_or_path,
196
- "adapter_config.json",
197
- _raise_exceptions_for_missing_entries=False,
198
- **cache_kwargs,
199
- )
200
- if adapter_config is not None:
201
- from peft import PeftModel
202
-
203
- # Pass original repo ID to PEFT, let it handle caching
204
- model.language_model = PeftModel.from_pretrained(
205
- model.language_model, pretrained_model_name_or_path, is_trainable=False
206
- )
207
-
208
  return model
209
  finally:
210
  cls._is_loading_from_pretrained = False
@@ -728,14 +711,14 @@ class ASRModel(PreTrainedModel, GenerationMixin):
728
  thread.start()
729
 
730
  # Yield tokens as they're generated, filtering out <think>...</think> blocks
731
- # SmolLM3 always starts in thinking mode, so assume we're in a think block
732
- in_think_block = True
733
  buffer = ""
734
 
735
  for text in streamer:
736
  buffer += text
737
 
738
- # Check for think block start (in case model outputs multiple think blocks)
739
  while "<think>" in buffer:
740
  in_think_block = True
741
  # Yield any text before <think>
 
51
  raise ValueError(f"mask_length must be >= 1, got {mask_length}")
52
 
53
  if mask_length > sequence_length:
54
+ raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}")
 
 
55
 
56
  # Compute number of masked spans per sample
57
  num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item())
 
188
  state_dict = load_file(model_file)
189
  model.load_state_dict(state_dict, strict=False)
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  return model
192
  finally:
193
  cls._is_loading_from_pretrained = False
 
711
  thread.start()
712
 
713
  # Yield tokens as they're generated, filtering out <think>...</think> blocks
714
+ # Start assuming no think block - only filter when we see <think>
715
+ in_think_block = False
716
  buffer = ""
717
 
718
  for text in streamer:
719
  buffer += text
720
 
721
+ # Check for think block start (in case model outputs think blocks)
722
  while "<think>" in buffer:
723
  in_think_block = True
724
  # Yield any text before <think>
asr_pipeline.py CHANGED
@@ -507,6 +507,4 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
507
  break
508
 
509
  # 3. STRIP WHITESPACE
510
- text = re.sub(r'\s+', ' ', text).strip()
511
-
512
- return text
 
507
  break
508
 
509
  # 3. STRIP WHITESPACE
510
+ return re.sub(r"\s+", " ", text).strip()
 
 
asr_processing.py CHANGED
@@ -106,7 +106,7 @@ class ASRProcessor(ProcessorMixin):
106
  input_ids = tokenized
107
  else:
108
  # BatchEncoding or dict-like object
109
- input_ids = tokenized["input_ids"] if "input_ids" in tokenized else tokenized.input_ids
110
 
111
  if input_ids.dim() == 1:
112
  input_ids = input_ids.unsqueeze(0)
 
106
  input_ids = tokenized
107
  else:
108
  # BatchEncoding or dict-like object
109
+ input_ids = tokenized.get("input_ids", tokenized.input_ids)
110
 
111
  if input_ids.dim() == 1:
112
  input_ids = input_ids.unsqueeze(0)