Update custom model files, README, and requirements
Files changed:
- asr_modeling.py +6 -13
- asr_pipeline.py +4 -0
asr_modeling.py (CHANGED)
@@ -673,6 +673,12 @@ class ASRModel(PreTrainedModel):
 
         num_audio_tokens = audio_embeds.shape[1]
         expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
+
+        # Debug: Show what prompt we built
+        import sys
+        prompt_text = self.tokenizer.decode(expanded_prompt_ids[0], skip_special_tokens=False)
+        print(f"DEBUG generate: Built prompt: {prompt_text[:200]}", file=sys.stderr)
+
         inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
         total_seq_len = inputs_embeds.shape[1]
         attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
@@ -700,14 +706,8 @@ class ASRModel(PreTrainedModel):
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
         prompt_length = expanded_prompt_ids.shape[1]
 
-        # Debug: Compare with streaming version
-        import sys
-        print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
-        print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)
-
         # Generate with or without streamer
         if streamer is not None:
-            print(f"DEBUG generate: Using streamer", file=sys.stderr)
             generated_ids = self.decoder.generate(
                 input_ids=expanded_prompt_ids,
                 inputs_embeds=inputs_embeds,
@@ -715,20 +715,13 @@ class ASRModel(PreTrainedModel):
                 streamer=streamer,
                 **generate_kwargs,
             )
-            # Debug what was generated
-            generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
-            print(f"DEBUG generate with streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
         else:
-            print(f"DEBUG generate: No streamer", file=sys.stderr)
             generated_ids = self.decoder.generate(
                 input_ids=expanded_prompt_ids,
                 inputs_embeds=inputs_embeds,
                 attention_mask=attention_mask,
                 **generate_kwargs,
             )
-        # Debug what was generated
-        generated_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
-        print(f"DEBUG generate without streamer: Generated text: {generated_text[:100]}", file=sys.stderr)
 
         return generated_ids[:, prompt_length:]
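Note (not part of the commit): the diff does not show `_expand_audio_tokens` itself, but from its call site it appears to replace the prompt's single audio placeholder with `num_audio_tokens` copies, one per frame of `audio_embeds`, so that `_prepare_audio_inputs_embeds` can splice the audio embeddings in position for position. A minimal sketch of that idea, with a hypothetical placeholder id:

import torch

AUDIO_TOKEN_ID = 7  # hypothetical placeholder id; the real model defines its own

def expand_audio_tokens(prompt_ids: torch.Tensor, num_audio_tokens: int) -> torch.Tensor:
    # Replace the single audio placeholder with one copy per audio embedding,
    # so each audio frame gets its own position in the prompt.
    expanded = []
    for tok in prompt_ids[0].tolist():
        if tok == AUDIO_TOKEN_ID:
            expanded.extend([AUDIO_TOKEN_ID] * num_audio_tokens)
        else:
            expanded.append(tok)
    return torch.tensor([expanded], dtype=prompt_ids.dtype)

prompt_ids = torch.tensor([[101, 7, 102]])  # "<bos> <audio> <eos>", schematically
print(expand_audio_tokens(prompt_ids, 3))   # tensor([[101, 7, 7, 7, 102]])

Decoding `expanded_prompt_ids[0]` with `skip_special_tokens=False` then prints the full chat template including the placeholder run, which is what the new `DEBUG generate: Built prompt` line shows; the final `generated_ids[:, prompt_length:]` slice drops the prompt again so callers receive only newly generated tokens.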
asr_pipeline.py (CHANGED)
@@ -219,6 +219,10 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         generate_kwargs.setdefault("eos_token_id", im_end_id)
         generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
 
+        # Debug: Log what we're passing to generate
+        import sys
+        print(f"DEBUG _forward: task={task}, system_prompt={self.model.config.system_prompt}", file=sys.stderr)
+
         # Pass the appropriate input type to generate
         if is_whisper:
             # Whisper model - use input_features
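Note (not part of the commit): the new print fires on every forward pass and writes to stderr unconditionally. If the debug lines are meant to survive past this round of debugging, one common pattern is to gate them behind an opt-in environment variable. A sketch, using a hypothetical `ASR_DEBUG` flag that this repository does not define:

import os
import sys

ASR_DEBUG = os.environ.get("ASR_DEBUG") == "1"  # hypothetical opt-in switch

def debug(msg: str) -> None:
    # Keep stderr quiet unless the user explicitly asked for debug output.
    if ASR_DEBUG:
        print(f"DEBUG {msg}", file=sys.stderr)

# e.g., inside _forward:
#   debug(f"_forward: task={task}, system_prompt={self.model.config.system_prompt}")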