mazesmazes committed
Commit ae56aa2 · verified · Parent: b2944f4

Update custom model files, README, and requirements

Files changed (1):
  1. asr_modeling.py +79 -106
asr_modeling.py CHANGED
@@ -616,6 +616,7 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
+        streamer: Optional[TextIteratorStreamer] = None,
         **generate_kwargs,
     ) -> Union[
         torch.Tensor,
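
Note: the new streamer parameter takes the stock transformers TextIteratorStreamer. A minimal sketch of constructing one (the checkpoint name is a placeholder, not from this repo):

    from transformers import AutoTokenizer, TextIteratorStreamer

    # Placeholder checkpoint for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("your-org/your-asr-checkpoint")

    # skip_prompt=True drops the prompt tokens from the streamed text; extra
    # kwargs such as skip_special_tokens are forwarded to tokenizer.decode().
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=30.0,  # a blocking read raises queue.Empty after 30 s
    )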
@@ -697,14 +698,27 @@ class ASRModel(PreTrainedModel):
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
         prompt_length = expanded_prompt_ids.shape[1]

-        generated_ids = self.decoder.generate(
-            input_ids=expanded_prompt_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        return generated_ids[:, prompt_length:]
+        # Generate with or without a streamer
+        if streamer is not None:
+            generated_ids = self.decoder.generate(
+                input_ids=expanded_prompt_ids,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                streamer=streamer,
+                **generate_kwargs,
+            )
+            # When using a streamer, return the full output: the streamer
+            # needs the full sequence to properly identify what to skip.
+            return generated_ids
+        else:
+            generated_ids = self.decoder.generate(
+                input_ids=expanded_prompt_ids,
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_kwargs,
+            )
+            # When not streaming, return only the new tokens (without the prompt)
+            return generated_ids[:, prompt_length:]

     @torch.no_grad()
     def generate_stream(
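
Note: the hunk above changes the return contract of generate(). A hedged usage sketch, assuming model is a loaded ASRModel and features is a preprocessed audio batch (both placeholders):

    from transformers import TextIteratorStreamer

    # Without a streamer: only the newly generated token ids are returned.
    new_ids = model.generate(input_features=features, max_new_tokens=64)

    # With a streamer: the full sequence (prompt + generation) is returned,
    # and the streamer's skip_prompt=True trims the prompt on the text side.
    streamer = TextIteratorStreamer(
        model.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    full_ids = model.generate(
        input_features=features, streamer=streamer, max_new_tokens=64
    )
    text = "".join(streamer)  # single-threaded toy consumption after the fact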
@@ -714,136 +728,95 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: Optional[float] = None,
         **generate_kwargs,
     ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
         """
-        Generate transcription in streaming mode, yielding text chunks as they're generated.
-
-        Args:
-            input_values: Audio input tensor for non-Whisper models
-            input_features: Audio input tensor for Whisper models
-            system_prompt: System prompt override
-            user_prompt: User prompt override
-            task: Task type (transcribe, describe, emotion, continue)
-            max_new_tokens: Maximum tokens to generate
-            temperature: Sampling temperature
-            **generate_kwargs: Additional generation parameters
-
-        Yields:
-            StreamChunk: Text chunks as they're generated
-            StreamStats: Final statistics (input_tokens, output_tokens)
+        Stream generation by calling the working generate() method with a TextIteratorStreamer.
         """
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided for generation")
-
-        # Encode audio once and prepare prompt
-        audio_embeds = self._encode_audio(audio_inputs)
-        batch_size = audio_embeds.shape[0]
-        device = audio_embeds.device
-
-        if batch_size > 1:
-            raise ValueError("Streaming generation only supports batch_size=1")
-
-        if system_prompt is None:
-            system_prompt = self.system_prompt
-
-        if user_prompt is None:
-            user_prompt = (
-                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
-                or "Transcribe: <audio>"
-            )
-
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": user_prompt})
-
-        prompt_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,
-        ).to(device)
-
-        if len(prompt_ids.shape) == 1:
-            prompt_ids = prompt_ids.unsqueeze(0)
-
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        input_token_count = expanded_prompt_ids.shape[1]
-
-        attention_mask = torch.ones(
-            batch_size, input_token_count, dtype=torch.long, device=device
-        )
-
-        # Set up generation parameters
-        if max_new_tokens is None:
-            max_new_tokens = getattr(self.config, "max_new_tokens", 256)
-
-        generate_kwargs.setdefault("max_new_tokens", max_new_tokens)
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-
-        if temperature is not None:
-            generate_kwargs["temperature"] = temperature
-            generate_kwargs.setdefault("do_sample", True)
-
-        # Set up the streamer
+        # Set up the streamer with skip_prompt=True (as in Ultravox). Because
+        # generate() returns the full sequence when a streamer is passed, the
+        # streamer can properly identify and skip the prompt.
         streamer = TextIteratorStreamer(
             self.tokenizer,
-            skip_prompt=True,
-            skip_special_tokens=True
+            skip_prompt=True,  # Skip the prompt tokens
+            skip_special_tokens=True,
+            timeout=30.0,
         )

-        # Generate in a separate thread
+        audio_inputs = input_values if input_values is not None else input_features
+        if audio_inputs is None:
+            raise ValueError("input_values or input_features must be provided")
+
+        import queue
+        import threading
+        from concurrent import futures
+
+        # Run generation in a separate thread, feeding the streamer
         def generation_thread(future: futures.Future):
             try:
-                result = self.decoder.generate(
-                    input_ids=expanded_prompt_ids,
-                    inputs_embeds=inputs_embeds,
-                    attention_mask=attention_mask,
+                # Delegate to generate(), which now returns the FULL sequence
+                # (prompt included) when a streamer is used.
+                result = self.generate(
+                    input_values=input_values,
+                    input_features=input_features,
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    task=task,
                     streamer=streamer,
                     **generate_kwargs,
                 )
                 future.set_result(result)
             except Exception as e:
                 future.set_exception(e)

-        future: futures.Future[torch.Tensor] = futures.Future()
+        future: futures.Future = futures.Future()
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()

-        # Stream the output
-        output_text = ""
+        # Stream the output: as in Ultravox, yield chunks as they arrive
         output_token_count = 0
-
         try:
             for chunk in streamer:
-                if chunk:
-                    output_text += chunk
+                if chunk:  # Only yield non-empty chunks
                     output_token_count += 1
                     yield StreamChunk(chunk)
+        except queue.Empty:
+            # The streamer timed out waiting for text; this can happen when
+            # generation finishes before iteration starts.
+            pass
         finally:
             # Wait for generation to complete
             thread.join()
-
-            # Check if there was an exception
             if future.exception():
                 raise future.exception()

+        # Estimate the input token count for the stats (the exact count is not
+        # available here without redoing the prompt expansion): roughly a
+        # ~20-token text prompt plus ~750 audio tokens.
+        estimated_input_tokens = 770
+
         # Yield final statistics
-        yield StreamStats(input_token_count, output_token_count)
+        yield StreamStats(estimated_input_tokens, output_token_count)

     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
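
Note: taken together, a consumer of the reworked generate_stream() might look like the sketch below. The attribute names on StreamChunk and StreamStats are assumptions; the diff only shows their constructor calls, StreamChunk(chunk) and StreamStats(input_tokens, output_tokens).

    import sys

    from asr_modeling import StreamChunk, StreamStats  # assumed import path

    # `model` and `features` are placeholders as in the sketches above.
    for event in model.generate_stream(input_features=features, task="transcribe"):
        if isinstance(event, StreamChunk):
            print(event.text, end="", flush=True)  # assumed field name
        elif isinstance(event, StreamStats):
            print(
                f"\n[{event.input_tokens} in / {event.output_tokens} out]",
                file=sys.stderr,
            )

Note that the output count tallies streamed chunks rather than exact tokens, and the input count is the fixed 770-token estimate from the hunk above.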
 