mazesmazes committed
Commit b2944f4 · verified · 1 Parent(s): 3e0fdca

Training in progress - step 8000

Files changed (1):
  1. asr_modeling.py +106 -71
asr_modeling.py CHANGED
@@ -616,7 +616,6 @@ class ASRModel(PreTrainedModel):
616
  system_prompt: Optional[str] = None,
617
  user_prompt: Optional[str] = None,
618
  task: Optional[str] = None,
619
- streamer: Optional[TextIteratorStreamer] = None,
620
  **generate_kwargs,
621
  ) -> Union[
622
  torch.Tensor,
@@ -698,27 +697,14 @@ class ASRModel(PreTrainedModel):
698
  generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
699
  prompt_length = expanded_prompt_ids.shape[1]
700
 
701
- # Generate with or without streamer
702
- if streamer is not None:
703
- generated_ids = self.decoder.generate(
704
- input_ids=expanded_prompt_ids,
705
- inputs_embeds=inputs_embeds,
706
- attention_mask=attention_mask,
707
- streamer=streamer,
708
- **generate_kwargs,
709
- )
710
- # When using a streamer, return the full output (streamer will handle skipping prompt)
711
- # The streamer needs the full sequence to properly identify what to skip
712
- return generated_ids
713
- else:
714
- generated_ids = self.decoder.generate(
715
- input_ids=expanded_prompt_ids,
716
- inputs_embeds=inputs_embeds,
717
- attention_mask=attention_mask,
718
- **generate_kwargs,
719
- )
720
- # When not streaming, return only the new tokens (without prompt)
721
- return generated_ids[:, prompt_length:]
722
 
723
  @torch.no_grad()
724
  def generate_stream(
@@ -728,39 +714,105 @@ class ASRModel(PreTrainedModel):
728
  system_prompt: Optional[str] = None,
729
  user_prompt: Optional[str] = None,
730
  task: Optional[str] = None,
 
 
731
  **generate_kwargs,
732
  ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
733
  """
734
- Stream generation by using the working generate() method with a TextIteratorStreamer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  """
736
- # Set up the streamer - use skip_prompt=True like Ultravox
737
- # The key is that when we return the full sequence from generate(),
738
- # the streamer can properly identify and skip the prompt
739
- streamer = TextIteratorStreamer(
740
- self.tokenizer,
741
- skip_prompt=True, # Skip the prompt tokens
742
- skip_special_tokens=True,
743
- timeout=30.0
744
- )
745
-
746
  audio_inputs = input_values if input_values is not None else input_features
747
  if audio_inputs is None:
748
- raise ValueError("input_values or input_features must be provided")
 
 
 
 
 
749
 
750
- import threading
751
- from concurrent import futures
752
 
753
- # Run generation in a thread with streamer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  def generation_thread(future: futures.Future):
755
  try:
756
- # Call generate with the streamer
757
- # Important: This now returns the FULL sequence when streaming
758
- result = self.generate(
759
- input_values=input_values,
760
- input_features=input_features,
761
- system_prompt=system_prompt,
762
- user_prompt=user_prompt,
763
- task=task,
764
  streamer=streamer,
765
  **generate_kwargs,
766
  )
@@ -768,47 +820,30 @@ class ASRModel(PreTrainedModel):
768
  except Exception as e:
769
  future.set_exception(e)
770
 
771
- future: futures.Future = futures.Future()
772
  thread = threading.Thread(target=generation_thread, args=(future,))
773
  thread.start()
774
 
775
- # Stream the output - like Ultravox, just yield chunks as they come
 
776
  output_token_count = 0
 
777
  try:
778
  for chunk in streamer:
779
- if chunk: # Only yield non-empty chunks
 
780
  output_token_count += 1
781
  yield StreamChunk(chunk)
782
- except Exception as e:
783
- # Check if it's the Empty exception from queue
784
- if e.__class__.__name__ == "Empty":
785
- # This happens when generation completes before we start iterating
786
- pass
787
- else:
788
- # Re-raise other exceptions
789
- raise
790
  finally:
791
  # Wait for generation to complete
792
  thread.join()
 
 
793
  if future.exception():
794
  raise future.exception()
795
 
796
- # Debug: If no chunks were yielded, check what was generated
797
- if output_token_count == 0:
798
- import sys
799
- result = future.result()
800
- if result is not None:
801
- # Note: result now includes the full sequence (including prompt)
802
- # when streaming, so decode the full thing
803
- decoded = self.tokenizer.decode(result[0], skip_special_tokens=True)
804
- print(f"DEBUG: No chunks yielded but generated: {decoded}", file=sys.stderr)
805
-
806
- # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
807
- # Rough estimate: prompt is about 20 tokens + 750 audio tokens
808
- estimated_input_tokens = 770
809
-
810
  # Yield final statistics
811
- yield StreamStats(estimated_input_tokens, output_token_count)
812
 
813
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
814
  import shutil
 
616
  system_prompt: Optional[str] = None,
617
  user_prompt: Optional[str] = None,
618
  task: Optional[str] = None,
 
619
  **generate_kwargs,
620
  ) -> Union[
621
  torch.Tensor,
 
697
  generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
698
  prompt_length = expanded_prompt_ids.shape[1]
699
 
700
+ generated_ids = self.decoder.generate(
701
+ input_ids=expanded_prompt_ids,
702
+ inputs_embeds=inputs_embeds,
703
+ attention_mask=attention_mask,
704
+ **generate_kwargs,
705
+ )
706
+
707
+ return generated_ids[:, prompt_length:]
 
 
 
 
 
 
 
 
 
 
 
 
 
708
 
709
  @torch.no_grad()
710
  def generate_stream(
 
714
  system_prompt: Optional[str] = None,
715
  user_prompt: Optional[str] = None,
716
  task: Optional[str] = None,
717
+ max_new_tokens: Optional[int] = None,
718
+ temperature: Optional[float] = None,
719
  **generate_kwargs,
720
  ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
721
  """
722
+ Generate transcription in streaming mode, yielding text chunks as they're generated.
723
+
724
+ Args:
725
+ input_values: Audio input tensor for non-Whisper models
726
+ input_features: Audio input tensor for Whisper models
727
+ system_prompt: System prompt override
728
+ user_prompt: User prompt override
729
+ task: Task type (transcribe, describe, emotion, continue)
730
+ max_new_tokens: Maximum tokens to generate
731
+ temperature: Sampling temperature
732
+ **generate_kwargs: Additional generation parameters
733
+
734
+ Yields:
735
+ StreamChunk: Text chunks as they're generated
736
+ StreamStats: Final statistics (input_tokens, output_tokens)
737
  """
 
 
 
 
 
 
 
 
 
 
738
  audio_inputs = input_values if input_values is not None else input_features
739
  if audio_inputs is None:
740
+ raise ValueError("input_values or input_features must be provided for generation")
741
+
742
+ # Encode audio once and prepare prompt
743
+ audio_embeds = self._encode_audio(audio_inputs)
744
+ batch_size = audio_embeds.shape[0]
745
+ device = audio_embeds.device
746
 
747
+ if batch_size > 1:
748
+ raise ValueError("Streaming generation only supports batch_size=1")
749
 
750
+ if system_prompt is None:
751
+ system_prompt = self.system_prompt
752
+
753
+ if user_prompt is None:
754
+ user_prompt = (
755
+ self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
756
+ or "Transcribe: <audio>"
757
+ )
758
+
759
+ messages = []
760
+ if system_prompt:
761
+ messages.append({"role": "system", "content": system_prompt})
762
+ messages.append({"role": "user", "content": user_prompt})
763
+
764
+ prompt_ids = self.tokenizer.apply_chat_template(
765
+ messages,
766
+ tokenize=True,
767
+ add_generation_prompt=True,
768
+ return_tensors="pt",
769
+ enable_thinking=False,
770
+ ).to(device)
771
+
772
+ if len(prompt_ids.shape) == 1:
773
+ prompt_ids = prompt_ids.unsqueeze(0)
774
+
775
+ if not (prompt_ids == self.audio_token_id).any():
776
+ raise ValueError("Audio token <audio> not found in prompt")
777
+
778
+ num_audio_tokens = audio_embeds.shape[1]
779
+ expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
780
+ inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
781
+ input_token_count = expanded_prompt_ids.shape[1]
782
+
783
+ attention_mask = torch.ones(
784
+ batch_size, input_token_count, dtype=torch.long, device=device
785
+ )
786
+
787
+ # Set up generation parameters
788
+ if max_new_tokens is None:
789
+ max_new_tokens = getattr(self.config, "max_new_tokens", 256)
790
+
791
+ generate_kwargs.setdefault("max_new_tokens", max_new_tokens)
792
+ generate_kwargs.setdefault("use_cache", True)
793
+ generate_kwargs.setdefault(
794
+ "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
795
+ )
796
+ generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
797
+
798
+ if temperature is not None:
799
+ generate_kwargs["temperature"] = temperature
800
+ generate_kwargs.setdefault("do_sample", True)
801
+
802
+ # Set up the streamer
803
+ streamer = TextIteratorStreamer(
804
+ self.tokenizer,
805
+ skip_prompt=True,
806
+ skip_special_tokens=True
807
+ )
808
+
809
+ # Generate in a separate thread
810
  def generation_thread(future: futures.Future):
811
  try:
812
+ result = self.decoder.generate(
813
+ input_ids=expanded_prompt_ids,
814
+ inputs_embeds=inputs_embeds,
815
+ attention_mask=attention_mask,
 
 
 
 
816
  streamer=streamer,
817
  **generate_kwargs,
818
  )
 
820
  except Exception as e:
821
  future.set_exception(e)
822
 
823
+ future: futures.Future[torch.Tensor] = futures.Future()
824
  thread = threading.Thread(target=generation_thread, args=(future,))
825
  thread.start()
826
 
827
+ # Stream the output
828
+ output_text = ""
829
  output_token_count = 0
830
+
831
  try:
832
  for chunk in streamer:
833
+ if chunk:
834
+ output_text += chunk
835
  output_token_count += 1
836
  yield StreamChunk(chunk)
 
 
 
 
 
 
 
 
837
  finally:
838
  # Wait for generation to complete
839
  thread.join()
840
+
841
+ # Check if there was an exception
842
  if future.exception():
843
  raise future.exception()
844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
845
  # Yield final statistics
846
+ yield StreamStats(input_token_count, output_token_count)
847
 
848
  def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
849
  import shutil
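
For reference, a caller of the reworked streaming API iterates over the generator and handles the two yielded types. The sketch below is illustrative only: generate_stream, StreamChunk, and StreamStats come from this file, but the attribute names on the yielded objects (text, input_tokens, output_tokens) and the model/input setup are assumptions, not part of this commit.

# Hypothetical consumer sketch for generate_stream() as changed above.
# `model` is an ASRModel instance; `input_features` is a preprocessed audio
# tensor. Attribute names on StreamChunk/StreamStats are assumed here.
from asr_modeling import StreamChunk, StreamStats

def consume_stream(model, input_features):
    for item in model.generate_stream(
        input_features=input_features,  # or input_values= for non-Whisper encoders
        task="transcribe",
        max_new_tokens=256,
        temperature=None,  # None keeps default decoding; a float enables sampling
    ):
        if isinstance(item, StreamChunk):
            # Text chunks arrive as soon as the TextIteratorStreamer decodes them.
            print(item.text, end="", flush=True)  # assumed attribute name
        elif isinstance(item, StreamStats):
            # The final item now reports the exact prompt length instead of the
            # old hard-coded estimate of 770 tokens.
            print(f"\n[{item.input_tokens} prompt tokens, "
                  f"{item.output_tokens} generated]")  # assumed attribute names

Note the design change this commit records: instead of routing streaming through generate() with an optional streamer argument, generate_stream() now prepares the prompt and audio embeddings itself and calls self.decoder.generate() directly, which lets it report an exact input token count in StreamStats.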