mazesmazes commited on
Commit
2d11f35
·
verified ·
1 Parent(s): 3503ca3

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +16 -0
asr_modeling.py CHANGED
@@ -840,6 +840,22 @@ class ASRModel(PreTrainedModel):
840
  print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
841
  print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
843
  # Test: Try without threading first to see if that's the issue
844
  print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
845
  test_output = self.decoder.generate(
 
840
  print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
841
  print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
842
 
843
+ # Debug: Check devices and values
844
+ print(f"DEBUG: inputs_embeds device={inputs_embeds.device}", file=sys.stderr)
845
+ print(f"DEBUG: expanded_prompt_ids device={expanded_prompt_ids.device}", file=sys.stderr)
846
+ print(f"DEBUG: attention_mask device={attention_mask.device}", file=sys.stderr)
847
+ print(f"DEBUG: decoder device={next(self.decoder.parameters()).device}", file=sys.stderr)
848
+
849
+ # Check if audio embeddings are non-zero
850
+ audio_mask = (expanded_prompt_ids == self.audio_token_id)
851
+ print(f"DEBUG: audio_mask sum={audio_mask.sum().item()} (should be {num_audio_tokens})", file=sys.stderr)
852
+
853
+ # Check a sample of the embeddings where audio should be
854
+ audio_positions = torch.where(audio_mask[0])[0]
855
+ if len(audio_positions) > 0:
856
+ sample_pos = audio_positions[0].item()
857
+ print(f"DEBUG: Sample audio embed at pos {sample_pos}: mean={inputs_embeds[0, sample_pos].mean().item():.4f}, std={inputs_embeds[0, sample_pos].std().item():.4f}", file=sys.stderr)
858
+
859
  # Test: Try without threading first to see if that's the issue
860
  print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
861
  test_output = self.decoder.generate(