mazesmazes committed on
Commit
e393d4c
·
verified ·
1 Parent(s): 45f7226

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_modeling.py +6 -13
  2. asr_pipeline.py +4 -0
asr_modeling.py CHANGED
@@ -673,6 +673,12 @@ class ASRModel(PreTrainedModel):
673
 
674
  num_audio_tokens = audio_embeds.shape[1]
675
  expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
 
 
 
 
 
 
676
  inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
677
  total_seq_len = inputs_embeds.shape[1]
678
  attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
@@ -700,14 +706,8 @@ class ASRModel(PreTrainedModel):
700
  generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
701
  prompt_length = expanded_prompt_ids.shape[1]
702
 
703
- # Debug: Compare with streaming version
704
- import sys
705
- print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
706
- print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
707
-
708
  # Generate with or without streamer
709
  if streamer is not None:
710
- print(f"DEBUG generate: Using streamer", file=sys.stderr)
711
  generated_ids = self.decoder.generate(
712
  input_ids=expanded_prompt_ids,
713
  inputs_embeds=inputs_embeds,
@@ -715,20 +715,13 @@ class ASRModel(PreTrainedModel):
715
  streamer=streamer,
716
  **generate_kwargs,
717
  )
718
- # Debug what was generated
719
- generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
720
- print(f"DEBUG generate with streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
721
  else:
722
- print(f"DEBUG generate: No streamer", file=sys.stderr)
723
  generated_ids = self.decoder.generate(
724
  input_ids=expanded_prompt_ids,
725
  inputs_embeds=inputs_embeds,
726
  attention_mask=attention_mask,
727
  **generate_kwargs,
728
  )
729
- # Debug what was generated
730
- generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
731
- print(f"DEBUG generate without streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
732
 
733
  return generated_ids[:, prompt_length:]
734
 
 
673
 
674
  num_audio_tokens = audio_embeds.shape[1]
675
  expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
676
+
677
+ # Debug: Show what prompt we built
678
+ import sys
679
+ prompt_text = self.tokenizer.decode(expanded_prompt_ids[0], skip_special_tokens=False)
680
+ print(f"DEBUG generate: Built prompt: {prompt_text[:200]}", file=sys.stderr)
681
+
682
  inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
683
  total_seq_len = inputs_embeds.shape[1]
684
  attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
 
706
  generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
707
  prompt_length = expanded_prompt_ids.shape[1]
708
 
 
 
 
 
 
709
  # Generate with or without streamer
710
  if streamer is not None:
 
711
  generated_ids = self.decoder.generate(
712
  input_ids=expanded_prompt_ids,
713
  inputs_embeds=inputs_embeds,
 
715
  streamer=streamer,
716
  **generate_kwargs,
717
  )
 
 
 
718
  else:
 
719
  generated_ids = self.decoder.generate(
720
  input_ids=expanded_prompt_ids,
721
  inputs_embeds=inputs_embeds,
722
  attention_mask=attention_mask,
723
  **generate_kwargs,
724
  )
 
 
 
725
 
726
  return generated_ids[:, prompt_length:]
727
 
asr_pipeline.py CHANGED
@@ -219,6 +219,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
219
  generate_kwargs.setdefault("eos_token_id", im_end_id)
220
  generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
221
 
 
 
 
 
222
  # Pass the appropriate input type to generate
223
  if is_whisper:
224
  # Whisper model - use input_features
 
219
  generate_kwargs.setdefault("eos_token_id", im_end_id)
220
  generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
221
 
222
+ # Debug: Log what we're passing to generate
223
+ import sys
224
+ print(f"DEBUG _forward: task={task}, system_prompt={self.model.config.system_prompt}", file=sys.stderr)
225
+
226
  # Pass the appropriate input type to generate
227
  if is_whisper:
228
  # Whisper model - use input_features