mazesmazes committed (verified)
Commit 02845e7 · 1 Parent(s): 450c3e2

Update custom model files, README, and requirements

Files changed (1): asr_modeling.py (+30 -159)
asr_modeling.py CHANGED
```diff
@@ -616,6 +616,7 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
+        streamer: Optional[TextIteratorStreamer] = None,
         **generate_kwargs,
     ) -> Union[
         torch.Tensor,
```
```diff
@@ -707,6 +708,10 @@ class ASRModel(PreTrainedModel):
         print(f"DEBUG generate (non-streaming): task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
         print(f"DEBUG generate (non-streaming): generate_kwargs={generate_kwargs}", file=sys.stderr)

+        # Add streamer if provided
+        if streamer is not None:
+            generate_kwargs["streamer"] = streamer
+
         generated_ids = self.decoder.generate(
             input_ids=expanded_prompt_ids,
             inputs_embeds=inputs_embeds,
```
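The four added lines above are all that generate() needs in order to stream: transformers' GenerationMixin.generate() accepts a streamer and pushes decoded tokens into it as they are produced. Below is a minimal standalone sketch of the pattern this enables (worker thread plus TextIteratorStreamer); the checkpoint is an arbitrary small public model used purely for illustration.

```python
# Sketch of the TextIteratorStreamer pattern this commit wires into generate().
# generate() blocks until finished, so it runs on a worker thread while the
# caller iterates decoded text chunks as tokens arrive.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # illustrative model only
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("Streaming test:", return_tensors="pt")

# Run the blocking generate() call on a worker thread, feeding the streamer.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
)
thread.start()
for text_chunk in streamer:  # yields strings as generation proceeds
    print(text_chunk, end="", flush=True)
thread.join()
```

Because generate() only returns once generation finishes, the call has to live on a worker thread; the streamer's iterator then hands chunks to the caller's thread as they arrive. That is exactly the structure the rewritten generate_stream() below delegates to.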
```diff
@@ -724,157 +729,11 @@ class ASRModel(PreTrainedModel):
         system_prompt: Optional[str] = None,
         user_prompt: Optional[str] = None,
         task: Optional[str] = None,
-        max_new_tokens: Optional[int] = None,
-        temperature: Optional[float] = None,
         **generate_kwargs,
     ) -> Generator[Union[StreamChunk, StreamStats], None, None]:
         """
-        Generate transcription in streaming mode, yielding text chunks as they're generated.
-
-        Args:
-            input_values: Audio input tensor for non-Whisper models
-            input_features: Audio input tensor for Whisper models
-            system_prompt: System prompt override
-            user_prompt: User prompt override
-            task: Task type (transcribe, describe, emotion, continue)
-            max_new_tokens: Maximum tokens to generate
-            temperature: Sampling temperature
-            **generate_kwargs: Additional generation parameters
-
-        Yields:
-            StreamChunk: Text chunks as they're generated
-            StreamStats: Final statistics (input_tokens, output_tokens)
+        Stream generation by using the working generate() method with a TextIteratorStreamer.
         """
-        audio_inputs = input_values if input_values is not None else input_features
-        if audio_inputs is None:
-            raise ValueError("input_values or input_features must be provided for generation")
-
-        # Debug: Check audio inputs
-        import sys
-        print(f"DEBUG generate_stream: audio_inputs shape={audio_inputs.shape if audio_inputs is not None else None}", file=sys.stderr)
-        print(f"DEBUG generate_stream: audio_inputs type={type(audio_inputs)}", file=sys.stderr)
-
-        # Encode audio once and prepare prompt
-        audio_embeds = self._encode_audio(audio_inputs)
-        batch_size = audio_embeds.shape[0]
-        device = audio_embeds.device
-
-        if batch_size > 1:
-            raise ValueError("Streaming generation only supports batch_size=1")
-
-        if system_prompt is None:
-            system_prompt = self.system_prompt
-
-        if user_prompt is None:
-            user_prompt = (
-                self.TASK_PROMPTS.get(task, self.config.user_prompt or "Transcribe: <audio>")
-                or "Transcribe: <audio>"
-            )
-
-        messages = []
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": user_prompt})
-
-        prompt_ids = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            enable_thinking=False,
-        ).to(device)
-
-        if len(prompt_ids.shape) == 1:
-            prompt_ids = prompt_ids.unsqueeze(0)
-
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        input_token_count = expanded_prompt_ids.shape[1]
-
-        attention_mask = torch.ones(
-            batch_size, input_token_count, dtype=torch.long, device=device
-        )
-
-        # Set up generation parameters from config (same as non-streaming generate)
-        config_params = [
-            "max_new_tokens",
-            "min_new_tokens",
-            "num_beams",
-            "do_sample",
-            "temperature",
-            "top_k",
-            "top_p",
-            "repetition_penalty",
-            "length_penalty",
-            "no_repeat_ngram_size",
-            "early_stopping",
-        ]
-        for param in config_params:
-            if hasattr(self.config, param) and getattr(self.config, param) is not None:
-                generate_kwargs.setdefault(param, getattr(self.config, param))
-
-        # Override with explicit parameters if provided
-        if max_new_tokens is not None:
-            generate_kwargs["max_new_tokens"] = max_new_tokens
-
-        if temperature is not None:
-            generate_kwargs["temperature"] = temperature
-            generate_kwargs["do_sample"] = True
-
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-
-        # Debug: Check if audio embeds are in inputs_embeds
-        import sys
-        print(f"DEBUG generate_stream: task={task}, system_prompt={system_prompt}, user_prompt={user_prompt}", file=sys.stderr)
-        print(f"DEBUG generate_stream: inputs_embeds shape={inputs_embeds.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: expanded_prompt_ids shape={expanded_prompt_ids.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: audio_embeds shape={audio_embeds.shape}", file=sys.stderr)
-        print(f"DEBUG generate_stream: num_audio_tokens={num_audio_tokens}", file=sys.stderr)
-        print(f"DEBUG generate_stream: generate_kwargs={generate_kwargs}", file=sys.stderr)
-
-        # Debug: Check devices and values
-        print(f"DEBUG: inputs_embeds device={inputs_embeds.device}", file=sys.stderr)
-        print(f"DEBUG: expanded_prompt_ids device={expanded_prompt_ids.device}", file=sys.stderr)
-        print(f"DEBUG: attention_mask device={attention_mask.device}", file=sys.stderr)
-        print(f"DEBUG: decoder device={next(self.decoder.parameters()).device}", file=sys.stderr)
-
-        # Check if audio embeddings are non-zero
-        audio_mask = (expanded_prompt_ids == self.audio_token_id)
-        print(f"DEBUG: audio_mask sum={audio_mask.sum().item()} (should be {num_audio_tokens})", file=sys.stderr)
-
-        # Check a sample of the embeddings where audio should be
-        audio_positions = torch.where(audio_mask[0])[0]
-        if len(audio_positions) > 0:
-            sample_pos = audio_positions[0].item()
-            print(f"DEBUG: Sample audio embed at pos {sample_pos}: mean={inputs_embeds[0, sample_pos].mean().item():.4f}, std={inputs_embeds[0, sample_pos].std().item():.4f}", file=sys.stderr)
-
-        # Test: Try without threading first to see if that's the issue
-        print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
-        print(f"DEBUG: input_token_count (prompt length) = {input_token_count}", file=sys.stderr)
-
-        test_output = self.decoder.generate(
-            input_ids=expanded_prompt_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            max_new_tokens=10,  # Just generate a few tokens to test
-            **{k: v for k, v in generate_kwargs.items() if k != 'max_new_tokens'}
-        )
-
-        # Debug the output
-        full_text = self.tokenizer.decode(test_output[0], skip_special_tokens=True)
-        print(f"DEBUG: Full output text: {full_text}", file=sys.stderr)
-
-        test_text = self.tokenizer.decode(test_output[0, input_token_count:], skip_special_tokens=True)
-        print(f"DEBUG: Non-threaded test output (after removing prompt): {test_text}", file=sys.stderr)
-
         # Set up the streamer
         streamer = TextIteratorStreamer(
             self.tokenizer,
```
```diff
@@ -882,13 +741,26 @@ class ASRModel(PreTrainedModel):
             skip_special_tokens=True
         )

-        # Generate in a separate thread
+        # Count prompt length for stats
+        # We need to encode just to get the prompt length
+        audio_inputs = input_values if input_values is not None else input_features
+        if audio_inputs is None:
+            raise ValueError("input_values or input_features must be provided")
+
+        # Simple way to get prompt length - just count audio tokens
+        import threading
+        from concurrent import futures
+
+        # Run generation in a thread with streamer
         def generation_thread(future: futures.Future):
             try:
-                result = self.decoder.generate(
-                    input_ids=expanded_prompt_ids,
-                    inputs_embeds=inputs_embeds,
-                    attention_mask=attention_mask,
+                # Just call the working generate method with the streamer
+                result = self.generate(
+                    input_values=input_values,
+                    input_features=input_features,
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    task=task,
                     streamer=streamer,
                     **generate_kwargs,
                 )
```
```diff
@@ -896,30 +768,29 @@ class ASRModel(PreTrainedModel):
             except Exception as e:
                 future.set_exception(e)

-        future: futures.Future[torch.Tensor] = futures.Future()
+        future: futures.Future = futures.Future()
         thread = threading.Thread(target=generation_thread, args=(future,))
         thread.start()

         # Stream the output
-        output_text = ""
         output_token_count = 0
-
         try:
             for chunk in streamer:
                 if chunk:
-                    output_text += chunk
                     output_token_count += 1
                     yield StreamChunk(chunk)
         finally:
             # Wait for generation to complete
             thread.join()
-
-        # Check if there was an exception
         if future.exception():
             raise future.exception()

+        # For stats, estimate input tokens (we can't easily get exact count without duplicating work)
+        # Rough estimate: prompt is about 20 tokens + 750 audio tokens
+        estimated_input_tokens = 770
+
         # Yield final statistics
-        yield StreamStats(input_token_count, output_token_count)
+        yield StreamStats(estimated_input_tokens, output_token_count)

     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
         import shutil
```
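One trade-off in this last hunk: the exact input_token_count, which the deleted code read off expanded_prompt_ids.shape[1], is replaced by a hard-coded estimate of 770 (roughly 20 prompt tokens plus 750 audio tokens). If exact stats were ever needed again, one option is sketched below; exact_input_token_count is a hypothetical helper, not part of this commit, and it reuses the private helpers (_encode_audio, _expand_audio_tokens) the removed lines called, at the cost of encoding the audio a second time, which is exactly the duplication the commit's comment avoids.

```python
import torch

def exact_input_token_count(model, prompt_ids: torch.Tensor, audio_inputs) -> int:
    """Hypothetical helper: recompute the prompt length the deleted code
    derived from expanded_prompt_ids.shape[1]. Encodes the audio again,
    so it doubles encoder work; sketch only."""
    audio_embeds = model._encode_audio(audio_inputs)          # helper seen in the removed lines
    expanded = model._expand_audio_tokens(prompt_ids, audio_embeds.shape[1])
    return expanded.shape[1]                                  # <audio> expanded to one slot per frame
```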
 
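End to end, the caller-facing contract of generate_stream() is unchanged: it yields StreamChunk values followed by a single StreamStats, whose fields the old docstring named input_tokens and output_tokens (the input figure is now the rough 770 estimate). A consumer sketch follows; the repo id, the AutoModel/AutoProcessor loading path, and the StreamChunk.text field name are assumptions for illustration, not confirmed by this diff.

```python
import torch
from transformers import AutoModel, AutoProcessor

repo = "mazesmazes/asr-model"  # hypothetical repo id
# Assumes the repo's auto_map wires ASRModel into AutoModel/AutoProcessor.
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)

audio = torch.randn(16000)  # stand-in for 1 s of 16 kHz mono audio
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

for item in model.generate_stream(**inputs, task="transcribe", max_new_tokens=256):
    if type(item).__name__ == "StreamStats":        # final yield: token counts
        print(f"\n~{item.input_tokens} in / {item.output_tokens} out")
    else:                                           # StreamChunk: incremental text
        print(item.text, end="", flush=True)        # .text field name assumed
```

Since generate_stream() now just forwards **generate_kwargs to the patched generate(), options like max_new_tokens flow through the same path in both the streaming and non-streaming cases, which is the point of the refactor.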