Update custom model files, README, and requirements

asr_modeling.py  CHANGED  (+30 −159)
@@ -616,6 +616,7 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
+        streamer: Optional[TextIteratorStreamer] = None,
         **generate_kwargs,
     ) -> Union[
         torch.Tensor,

@@ -707,6 +708,10 @@ class ASRModel(PreTrainedModel):
         print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
         print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)

+        # Add streamer if provided
+        if streamer is not None:
+            generate_kwargs["streamer"] = streamer
+
         generated_ids = self.decoder.generate(
             input_ids=expanded_prompt_ids,
             inputs_embeds=inputs_embeds,

@@ -724,157 +729,11 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: Optional[float] = None,
         **generate_kwargs,
     ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
         """
-
-
-        Args:
-            input_values: Audio input tensor for non-Whisper models
-            input_features: Audio input tensor for Whisper models
-            system_prompt: System prompt override
-            user_prompt: User prompt override
-            task: Task type (transcribe, describe, emotion, continue)
-            max_new_tokens: Maximum tokens to generate
-            temperature: Sampling temperature
-            **generate_kwargs: Additional generation parameters
-
-        Yields:
-            StreamChunk: Text chunks as they're generated
-            StreamStats: Final statistics (input_tokens, output_tokens)
+        Stream generation by using the working generate() method with a TextIteratorStreamer.
         """
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided for generation")
-
-        # Debug: Check audio inputs
-        import sys
-        print(f"DEBUG generate_stream: audio_inputs shape={audio_inputs.shape if audio_inputs is not None else None}", file=sys.stderr)
-        print(f"DEBUG generate_stream: audio_inputs type={type(audio_inputs)}", file=sys.stderr)
-
-        # Encode audio once and prepare prompt
-        audio_embeds = self._encode_audio(audio_inputs)
-        batch_size = audio_embeds.shape[0]
-        device = audio_embeds.device
-
-        if batch_size > 1:
-            raise ValueError("Streaming generation only supports batch_size=1")
-
-        if system_prompt is None:
-            system_prompt = self.system_prompt
-
-        if user_prompt is None:
-            user_prompt = (
-                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
-                or "Transcribe: <audio>"
-            )
-
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": user_prompt})
-
-        prompt_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,
-        ).to(device)
-
-        if len(prompt_ids.shape) == 1:
-            prompt_ids = prompt_ids.unsqueeze(0)
-
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        input_token_count = expanded_prompt_ids.shape[1]
-
-        attention_mask = torch.ones(
-            batch_size, input_token_count, dtype=torch.long, device=device
-        )
-
-        # Set up generation parameters from config (same as non-streaming generate)
-        config_params = [
-            "max_new_tokens",
-            "min_new_tokens",
-            "num_beams",
-            "do_sample",
-            "temperature",
-            "top_k",
-            "top_p",
-            "repetition_penalty",
-            "length_penalty",
-            "no_repeat_ngram_size",
-            "early_stopping",
-        ]
-        for param in config_params:
-            if hasattr(self.config, param) and getattr(self.config, param) is not None:
-                generate_kwargs.setdefault(param, getattr(self.config, param))
-
-        # Override with explicit parameters if provided
-        if max_new_tokens is not None:
-            generate_kwargs["max_new_tokens"] = max_new_tokens
-
-        if temperature is not None:
-            generate_kwargs["temperature"] = temperature
-            generate_kwargs["do_sample"] = True
-
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-
-        # Debug: Check if audio embeds are in inputs_embeds
-        import sys
-        print(f"DEBUG generate_stream: task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
-        print(f"DEBUG generate_stream: inputs_embeds shape={inputs_embeds.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: expanded_prompt_ids shape={expanded_prompt_ids.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: audio_embeds shape={audio_embeds.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
-        print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
-
-        # Debug: Check devices and values
-        print(f"DEBUG: inputs_embeds device={inputs_embeds.device}", file=sys.stderr)
-        print(f"DEBUG: expanded_prompt_ids device={expanded_prompt_ids.device}", file=sys.stderr)
-        print(f"DEBUG: attention_mask device={attention_mask.device}", file=sys.stderr)
-        print(f"DEBUG: decoder device={next(self.decoder.parameters()).device}", file=sys.stderr)
-
-        # Check if audio embeddings are non-zero
-        audio_mask = (expanded_prompt_ids == self.audio_token_id)
-        print(f"DEBUG: audio_mask sum={audio_mask.sum().item()} (should be {num_audio_tokens})", file=sys.stderr)
-
-        # Check a sample of the embeddings where audio should be
-        audio_positions = torch.where(audio_mask[0])[0]
-        if len(audio_positions) > 0:
-            sample_pos = audio_positions[0].item()
-            print(f"DEBUG: Sample audio embed at pos {sample_pos}: mean={inputs_embeds[0, sample_pos].mean().item():.4f}, std={inputs_embeds[0, sample_pos].std().item():.4f}", file=sys.stderr)
-
-        # Test: Try without threading first to see if that's the issue
-        print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
-        print(f"DEBUG: input_token_count (prompt length) = {input_token_count}", file=sys.stderr)
-
-        test_output = self.decoder.generate(
-            input_ids=expanded_prompt_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            max_new_tokens=10,  # Just generate a few tokens to test
-            **{k: v for k, v in generate_kwargs.items() if k != 'max_new_tokens'}
-        )
-
-        # Debug the output
-        full_text = self.tokenizer.decode(test_output[0], skip_special_tokens=True)
-        print(f"DEBUG: Full output text: {full_text}", file=sys.stderr)
-
-        test_text = self.tokenizer.decode(test_output[0, input_token_count:], skip_special_tokens=True)
-        print(f"DEBUG: Non-threaded test output (after removing prompt): {test_text}", file=sys.stderr)
-
         # Set up the streamer
         streamer = TextIteratorStreamer(
             self.tokenizer,

@@ -882,13 +741,26 @@ class ASRModel(PreTrainedModel):
            skip_special_tokens=True
        )

-        #
+        # Count prompt length for stats
+        # We need to encode just to get the prompt length
+        audio_inputs = input_values if input_values is not None else input_features
+        if audio_inputs is None:
+            raise ValueError("input_values or input_features must be provided")
+
+        # Simple way to get prompt length - just count audio tokens
+        import threading
+        from concurrent import futures
+
+        # Run generation in a thread with streamer
        def generation_thread(future: futures.Future):
            try:
-
-
-
-
+                # Just call the working generate method with the streamer
+                result = self.generate(
+                    input_values=input_values,
+                    input_features=input_features,
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    task=task,
                    streamer=streamer,
                    **generate_kwargs,
                )

@@ -896,30 +768,29 @@ class ASRModel(PreTrainedModel):
            except Exception as e:
                future.set_exception(e)

-        future: futures.Future
+        future: futures.Future = futures.Future()
        thread = threading.Thread(target=generation_thread, args=(future,))
        thread.start()

        # Stream the output
-        output_text = ""
        output_token_count = 0
-
        try:
            for chunk in streamer:
                if chunk:
-                    output_text += chunk
                    output_token_count += 1
                    yield StreamChunk(chunk)
        finally:
            # Wait for generation to complete
            thread.join()
-
-            # Check if there was an exception
            if future.exception():
                raise future.exception()

+        # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
+        # Rough estimate: prompt is about 20 tokens + 750 audio tokens
+        estimated_input_tokens = 770
+
        # Yield final statistics
-        yield StreamStats(
+        yield StreamStats(estimated_input_tokens, output_token_count)

    def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
        import shutil