Update custom model files, README, and requirements
Browse files- asr_modeling.py +16 -0
asr_modeling.py
CHANGED
|
@@ -840,6 +840,22 @@ class ASRModel(PreTrainedModel):
|
|
| 840 |
print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
|
| 841 |
print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
|
| 842 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 843 |
# Test: Try without threading first to see if that's the issue
|
| 844 |
print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
|
| 845 |
test_output = self.decoder.generate(
|
|
|
|
| 840 |
print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
|
| 841 |
print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
|
| 842 |
|
| 843 |
+
# Debug: Check devices and values
|
| 844 |
+
print(f"DEBUG: inputs_embeds device={inputs_embeds.device}", file=sys.stderr)
|
| 845 |
+
print(f"DEBUG: expanded_prompt_ids device={expanded_prompt_ids.device}", file=sys.stderr)
|
| 846 |
+
print(f"DEBUG: attention_mask device={attention_mask.device}", file=sys.stderr)
|
| 847 |
+
print(f"DEBUG: decoder device={next(self.decoder.parameters()).device}", file=sys.stderr)
|
| 848 |
+
|
| 849 |
+
# Check if audio embeddings are non-zero
|
| 850 |
+
audio_mask = (expanded_prompt_ids == self.audio_token_id)
|
| 851 |
+
print(f"DEBUG: audio_mask sum={audio_mask.sum().item()} (should be {num_audio_tokens})", file=sys.stderr)
|
| 852 |
+
|
| 853 |
+
# Check a sample of the embeddings where audio should be
|
| 854 |
+
audio_positions = torch.where(audio_mask[0])[0]
|
| 855 |
+
if len(audio_positions) > 0:
|
| 856 |
+
sample_pos = audio_positions[0].item()
|
| 857 |
+
print(f"DEBUG: Sample audio embed at pos {sample_pos}: mean={inputs_embeds[0, sample_pos].mean().item():.4f}, std={inputs_embeds[0, sample_pos].std().item():.4f}", file=sys.stderr)
|
| 858 |
+
|
| 859 |
# Test: Try without threading first to see if that's the issue
|
| 860 |
print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
|
| 861 |
test_output = self.decoder.generate(
|