mazesmazes committed on
Commit
4d1a14c
·
verified ·
1 Parent(s): 4857a95

Update custom model files, README, and requirements

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -1
  2. asr_modeling.py +11 -9
  3. asr_processing.py +6 -4
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
4
- tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
  *.bin filter=lfs diff=lfs merge=lfs -text
3
  tokenizer_config.json -filter -diff -merge text
 
asr_modeling.py CHANGED
@@ -38,7 +38,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
38
  _is_loading_from_pretrained: bool = False
39
  _pretrained_model_path: Optional[str] = None
40
 
41
- TRANSCRIBE_PROMPT = "Transcribe speech to text" # Audio tokens come BEFORE this
42
 
43
  @classmethod
44
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
@@ -571,10 +571,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
571
  messages: list[dict[str, str]] = []
572
  if system_prompt:
573
  messages.append({"role": "system", "content": system_prompt})
574
- # Audio BEFORE prompt for proper causal attention
575
- messages.append(
576
- {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
577
- )
 
578
 
579
  chat_result = self.tokenizer.apply_chat_template(
580
  messages,
@@ -653,10 +654,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
653
  messages: list[dict[str, str]] = []
654
  if system_prompt:
655
  messages.append({"role": "system", "content": system_prompt})
656
- # Audio BEFORE prompt for proper causal attention
657
- messages.append(
658
- {"role": "user", "content": audio_placeholder + " " + self.TRANSCRIBE_PROMPT}
659
- )
 
660
 
661
  chat_result = self.tokenizer.apply_chat_template(
662
  messages,
 
38
  _is_loading_from_pretrained: bool = False
39
  _pretrained_model_path: Optional[str] = None
40
 
41
+ TRANSCRIBE_PROMPT = ""
42
 
43
  @classmethod
44
  def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
 
571
  messages: list[dict[str, str]] = []
572
  if system_prompt:
573
  messages.append({"role": "system", "content": system_prompt})
574
+ # Audio tokens only (instruction-free)
575
+ user_content = audio_placeholder
576
+ if self.TRANSCRIBE_PROMPT:
577
+ user_content += " " + self.TRANSCRIBE_PROMPT
578
+ messages.append({"role": "user", "content": user_content})
579
 
580
  chat_result = self.tokenizer.apply_chat_template(
581
  messages,
 
654
  messages: list[dict[str, str]] = []
655
  if system_prompt:
656
  messages.append({"role": "system", "content": system_prompt})
657
+ # Audio tokens only (instruction-free)
658
+ user_content = audio_placeholder
659
+ if self.TRANSCRIBE_PROMPT:
660
+ user_content += " " + self.TRANSCRIBE_PROMPT
661
+ messages.append({"role": "user", "content": user_content})
662
 
663
  chat_result = self.tokenizer.apply_chat_template(
664
  messages,
asr_processing.py CHANGED
@@ -17,7 +17,7 @@ class ASRProcessor(ProcessorMixin):
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
- TRANSCRIBE_PROMPT = "Transcribe speech to text"
21
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
 
@@ -89,11 +89,13 @@ class ASRProcessor(ProcessorMixin):
89
  else:
90
  num_audio_tokens = 0
91
 
92
- # Build prompt with audio token placeholders (audio BEFORE prompt)
93
  if num_audio_tokens > 0:
94
- user_content = self.AUDIO_TOKEN * num_audio_tokens + " " + self.TRANSCRIBE_PROMPT
 
 
95
  else:
96
- user_content = self.TRANSCRIBE_PROMPT
97
 
98
  messages = []
99
  if system_prompt:
 
17
  feature_extractor_class = "AutoFeatureExtractor"
18
  tokenizer_class = "AutoTokenizer"
19
  AUDIO_TOKEN = "<audio>"
20
+ TRANSCRIBE_PROMPT = ""
21
  # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
22
  DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]
23
 
 
89
  else:
90
  num_audio_tokens = 0
91
 
92
+ # Build prompt with audio token placeholders (instruction-free)
93
  if num_audio_tokens > 0:
94
+ user_content = self.AUDIO_TOKEN * num_audio_tokens
95
+ if self.TRANSCRIBE_PROMPT:
96
+ user_content += " " + self.TRANSCRIBE_PROMPT
97
  else:
98
+ user_content = self.TRANSCRIBE_PROMPT or ""
99
 
100
  messages = []
101
  if system_prompt: