mazesmazes committed on
Commit
0e94c99
·
verified ·
1 Parent(s): 0d417e9

Training in progress - step 16000

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. asr_modeling.py +9 -11
  3. asr_processing.py +4 -6
.gitattributes CHANGED
@@ -1,3 +1,4 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_modeling.py CHANGED
@@ -38,7 +38,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
38
  _is_loading_from_pretrained: bool = False
39
  _pretrained_model_path: Optional[str] = None
40
 
41
- TRANSCRIBE_PROMPT = ""
42
 
43
  @classmethod
44
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
@@ -571,11 +571,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
571
  messages: list[dict[str, str]] = []
572
  if system_prompt:
573
  messages.append({"role": "system", "content": system_prompt})
574
- # Audio tokens only (instruction-free)
575
- user_content = audio_placeholder
576
- if self.TRANSCRIBE_PROMPT:
577
- user_content += " " + self.TRANSCRIBE_PROMPT
578
- messages.append({"role": "user", "content": user_content})
579
 
580
  chat_result = self.tokenizer.apply_chat_template(
581
  messages,
@@ -654,11 +653,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
654
  messages: list[dict[str, str]] = []
655
  if system_prompt:
656
  messages.append({"role": "system", "content": system_prompt})
657
- # Audio tokens only (instruction-free)
658
- user_content = audio_placeholder
659
- if self.TRANSCRIBE_PROMPT:
660
- user_content += " " + self.TRANSCRIBE_PROMPT
661
- messages.append({"role": "user", "content": user_content})
662
 
663
  chat_result = self.tokenizer.apply_chat_template(
664
  messages,
 
38
  _is_loading_from_pretrained: bool = False
39
  _pretrained_model_path: Optional[str] = None
40
 
41
+ TRANSCRIBE_PROMPT = "Transcribe speech to text" # Audio tokens come BEFORE this
42
 
43
  @classmethod
44
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
 
571
  messages: list[dict[str, str]] = []
572
  if system_prompt:
573
  messages.append({"role": "system", "content": system_prompt})
574
+ # Audio BEFORE prompt for proper causal attention
575
+ messages.append(
576
+ {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
577
+ )
 
578
 
579
  chat_result = self.tokenizer.apply_chat_template(
580
  messages,
 
653
  messages: list[dict[str, str]] = []
654
  if system_prompt:
655
  messages.append({"role": "system", "content": system_prompt})
656
+ # Audio BEFORE prompt for proper causal attention
657
+ messages.append(
658
+ {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
659
+ )
 
660
 
661
  chat_result = self.tokenizer.apply_chat_template(
662
  messages,
asr_processing.py CHANGED
@@ -17,7 +17,7 @@ class ASRProcessor(ProcessorMixin):
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
- TRANSCRIBE_PROMPT = ""
21
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
 
@@ -89,13 +89,11 @@ class ASRProcessor(ProcessorMixin):
89
  else:
90
  num_audio_tokens = 0
91
 
92
- # Build prompt with audio token placeholders (instruction-free)
93
  if num_audio_tokens > 0:
94
- user_content = self.AUDIO_TOKEN * num_audio_tokens
95
- if self.TRANSCRIBE_PROMPT:
96
- user_content += " " + self.TRANSCRIBE_PROMPT
97
  else:
98
- user_content = self.TRANSCRIBE_PROMPT or ""
99
 
100
  messages = []
101
  if system_prompt:
 
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = "Transcribe speech to text"
21
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
 
 
89
  else:
90
  num_audio_tokens = 0
91
 
92
+ # Build prompt with audio token placeholders (audio BEFORE prompt)
93
  if num_audio_tokens > 0:
94
+ user_content = self.AUDIO_TOKEN * num_audio_tokens + " " + self.TRANSCRIBE_PROMPT
 
 
95
  else:
96
+ user_content = self.TRANSCRIBE_PROMPT
97
 
98
  messages = []
99
  if system_prompt: