Update custom model files, README, and requirements
Browse files- asr_modeling.py +8 -1
asr_modeling.py
CHANGED
|
@@ -858,6 +858,8 @@ class ASRModel(PreTrainedModel):
|
|
| 858 |
|
| 859 |
# Test: Try without threading first to see if that's the issue
|
| 860 |
print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
|
|
|
|
|
|
|
| 861 |
test_output = self.decoder.generate(
|
| 862 |
input_ids=expanded_prompt_ids,
|
| 863 |
inputs_embeds=inputs_embeds,
|
|
@@ -865,8 +867,13 @@ class ASRModel(PreTrainedModel):
|
|
| 865 |
max_new_tokens=10, # Just generate a few tokens to test
|
| 866 |
**{k: v for k, v in generate_kwargs.items() if k != 'max_new_tokens'}
|
| 867 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
test_text = self.tokenizer.decode(test_output[0, input_token_count:], skip_special_tokens=True)
|
| 869 |
-
print(f"DEBUG: Non-threaded test output: {test_text}", file=sys.stderr)
|
| 870 |
|
| 871 |
# Set up the streamer
|
| 872 |
streamer = TextIteratorStreamer(
|
|
|
|
| 858 |
|
| 859 |
# Test: Try without threading first to see if that's the issue
|
| 860 |
print(f"DEBUG: Testing non-threaded generation first", file=sys.stderr)
|
| 861 |
+
print(f"DEBUG: input_token_count (prompt length) = {input_token_count}", file=sys.stderr)
|
| 862 |
+
|
| 863 |
test_output = self.decoder.generate(
|
| 864 |
input_ids=expanded_prompt_ids,
|
| 865 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 867 |
max_new_tokens=10, # Just generate a few tokens to test
|
| 868 |
**{k: v for k, v in generate_kwargs.items() if k != 'max_new_tokens'}
|
| 869 |
)
|
| 870 |
+
|
| 871 |
+
# Debug the output
|
| 872 |
+
full_text = self.tokenizer.decode(test_output[0], skip_special_tokens=True)
|
| 873 |
+
print(f"DEBUG: Full output text: {full_text}", file=sys.stderr)
|
| 874 |
+
|
| 875 |
test_text = self.tokenizer.decode(test_output[0, input_token_count:], skip_special_tokens=True)
|
| 876 |
+
print(f"DEBUG: Non-threaded test output (after removing prompt): {test_text}", file=sys.stderr)
|
| 877 |
|
| 878 |
# Set up the streamer
|
| 879 |
streamer = TextIteratorStreamer(
|