Training in progress - step 8000
asr_modeling.py  (+106, -71)  CHANGED

@@ -616,7 +616,6 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
-        streamer: Optional[TextIteratorStreamer] = None,
         **generate_kwargs,
     ) -> Union[
         torch.Tensor,
@@ -698,27 +697,14 @@ class ASRModel(PreTrainedModel):
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
         prompt_length = expanded_prompt_ids.shape[1]
 
-        if streamer is not None:
-            generated_ids = self.decoder.generate(
-                input_ids=expanded_prompt_ids,
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                streamer=streamer,
-                **generate_kwargs,
-            )
-            # When using a streamer, return the full output (streamer will handle skipping prompt)
-            # The streamer needs the full sequence to properly identify what to skip
-            return generated_ids
-        else:
-            generated_ids = self.decoder.generate(
-                input_ids=expanded_prompt_ids,
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                **generate_kwargs,
-            )
-            # When not streaming, return only the new tokens (without prompt)
-            return generated_ids[:, prompt_length:]
+        generated_ids = self.decoder.generate(
+            input_ids=expanded_prompt_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            **generate_kwargs,
+        )
+
+        return generated_ids[:, prompt_length:]
 
     @torch.no_grad()
     def generate_stream(
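
With this change, `generate()` always strips the prompt, so callers of the non-streaming path receive only the newly generated token ids. A minimal caller sketch (the `transcribe` helper and the `max_new_tokens` value are illustrative, not part of this commit):

```python
import torch

def transcribe(model, tokenizer, input_features: torch.Tensor) -> str:
    # generate() now returns only the new tokens, shape (batch, generated_len),
    # so the decoded text no longer contains the chat prompt.
    new_token_ids = model.generate(input_features=input_features, max_new_tokens=256)
    return tokenizer.decode(new_token_ids[0], skip_special_tokens=True)
```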
@@ -728,39 +714,105 @@
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
         **generate_kwargs,
     ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
         """
+        Generate transcription in streaming mode, yielding text chunks as they're generated.
+
+        Args:
+            input_values: Audio input tensor for non-Whisper models
+            input_features: Audio input tensor for Whisper models
+            system_prompt: System prompt override
+            user_prompt: User prompt override
+            task: Task type (transcribe, describe, emotion, continue)
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            **generate_kwargs: Additional generation parameters
+
+        Yields:
+            StreamChunk: Text chunks as they're generated
+            StreamStats: Final statistics (input_tokens, output_tokens)
         """
-        # Set up the streamer - use skip_prompt=True like Ultravox
-        # The key is that when we return the full sequence from generate(),
-        # the streamer can properly identify and skip the prompt
-        streamer = TextIteratorStreamer(
-            self.tokenizer,
-            skip_prompt=True,  # Skip the prompt tokens
-            skip_special_tokens=True,
-            timeout=30.0
-        )
-
         audio_inputs = input_values if input_values is not None else input_features
         if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided")
+            raise ValueError("input_values or input_features must be provided for generation")
+
+        # Encode audio once and prepare prompt
+        audio_embeds = self._encode_audio(audio_inputs)
+        batch_size = audio_embeds.shape[0]
+        device = audio_embeds.device
 
+        if batch_size > 1:
+            raise ValueError("Streaming generation only supports batch_size=1")
 
+        if system_prompt is None:
+            system_prompt = self.system_prompt
+
+        if user_prompt is None:
+            user_prompt = (
+                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
+                or "Transcribe: <audio>"
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        prompt_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+            enable_thinking=False,
+        ).to(device)
+
+        if len(prompt_ids.shape) == 1:
+            prompt_ids = prompt_ids.unsqueeze(0)
+
+        if not (prompt_ids == self.audio_token_id).any():
+            raise ValueError("Audio token <audio> not found in prompt")
+
+        num_audio_tokens = audio_embeds.shape[1]
+        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
+        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
+        input_token_count = expanded_prompt_ids.shape[1]
+
+        attention_mask = torch.ones(
+            batch_size, input_token_count, dtype=torch.long, device=device
+        )
+
+        # Set up generation parameters
+        if max_new_tokens is None:
+            max_new_tokens = getattr(self.config, "max_new_tokens", 256)
+
+        generate_kwargs.setdefault("max_new_tokens", max_new_tokens)
+        generate_kwargs.setdefault("use_cache", True)
+        generate_kwargs.setdefault(
+            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        )
+        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
+
+        if temperature is not None:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs.setdefault("do_sample", True)
+
+        # Set up the streamer
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        # Generate in a separate thread
         def generation_thread(future: futures.Future):
             try:
-                result = self.generate(
-                    input_values=input_values,
-                    input_features=input_features,
-                    system_prompt=system_prompt,
-                    user_prompt=user_prompt,
-                    task=task,
+                result = self.decoder.generate(
+                    input_ids=expanded_prompt_ids,
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
                     streamer=streamer,
                     **generate_kwargs,
                 )
@@ -768,47 +820,30 @@
             except Exception as e:
                 future.set_exception(e)
 
-        future: futures.Future = futures.Future()
+        future: futures.Future[torch.Tensor] = futures.Future()
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()
 
-        # Stream the output
+        # Stream the output
+        output_text = ""
         output_token_count = 0
+
         try:
             for chunk in streamer:
-                if chunk:
+                if chunk:
+                    output_text += chunk
                     output_token_count += 1
                     yield StreamChunk(chunk)
-        except Exception as e:
-            # Check if it's the Empty exception from queue
-            if e.__class__.__name__ == "Empty":
-                # This happens when generation completes before we start iterating
-                pass
-            else:
-                # Re-raise other exceptions
-                raise
         finally:
             # Wait for generation to complete
             thread.join()
+
+            # Check if there was an exception
             if future.exception():
                 raise future.exception()
 
-            # Debug: If no chunks were yielded, check what was generated
-            if output_token_count == 0:
-                import sys
-                result = future.result()
-                if result is not None:
-                    # Note: result now includes the full sequence (including prompt)
-                    # when streaming, so decode the full thing
-                    decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
-                    print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)
-
-            # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
-            # Rough estimate: prompt is about 20 tokens + 750 audio tokens
-            estimated_input_tokens = 770
-
             # Yield final statistics
-            yield StreamStats(
+            yield StreamStats(input_token_count, output_token_count)
 
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
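
The streaming path now builds the prompt and audio embeddings inline and drives `self.decoder.generate` from a worker thread through a `TextIteratorStreamer`. The thread-plus-streamer pattern itself is standard `transformers` usage; a self-contained sketch with a plain text-only model (the model name is only for illustration, not part of this repo):

```python
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread and pushes
# decoded text into the streamer; the main thread consumes chunks as they arrive.
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```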
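
For consumers of `generate_stream`, the generator yields `StreamChunk` objects followed by a single `StreamStats` record. A hypothetical consumer sketch, assuming the attribute names `text`, `input_tokens`, and `output_tokens` on those dataclasses and that feature extraction happens elsewhere:

```python
from asr_modeling import StreamChunk  # assumed import location

def stream_transcription(model, input_features):
    text = ""
    for item in model.generate_stream(input_features=input_features, task="transcribe"):
        if isinstance(item, StreamChunk):
            text += item.text                      # assumed field name
            print(item.text, end="", flush=True)
        else:
            # StreamStats carries the final token counts (field names assumed)
            print(f"\n[{item.input_tokens} prompt tokens, {item.output_tokens} generated]")
    return text
```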