mazesmazes committed on
Commit
42db72a
·
verified ·
1 Parent(s): 89c5fc3

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +11 -3
asr_modeling.py CHANGED
@@ -130,7 +130,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
130
  encoder_kwargs = {
131
  "attn_implementation": config.attn_implementation,
132
  "low_cpu_mem_usage": True,
133
- "dtype": dtype,
134
  }
135
 
136
  if "whisper" in config.audio_model_id.lower():
@@ -143,13 +143,20 @@ class ASRModel(PreTrainedModel, GenerationMixin):
143
  # GLM-ASR models use audio_tower as the encoder
144
  # Requires transformers >= 5.x or installed from source
145
  from transformers import AutoModelForSeq2SeqLM
 
146
 
147
  full_model = AutoModelForSeq2SeqLM.from_pretrained(
148
  config.audio_model_id, trust_remote_code=True, **encoder_kwargs
149
  )
150
  # GLM stores encoder at audio_tower (GlmAsrEncoder)
151
  encoder = full_model.audio_tower
 
 
 
152
  del full_model
 
 
 
153
  else:
154
  encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
155
 
@@ -427,12 +434,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
427
  messages.append({"role": "system", "content": system_prompt})
428
  messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
429
 
430
- input_ids = self.tokenizer.apply_chat_template(
431
  messages,
432
  tokenize=True,
433
  add_generation_prompt=True,
434
  return_tensors="pt",
435
- ).to(device)
 
436
 
437
  if input_ids.dim() == 1:
438
  input_ids = input_ids.unsqueeze(0)
 
130
  encoder_kwargs = {
131
  "attn_implementation": config.attn_implementation,
132
  "low_cpu_mem_usage": True,
133
+ "torch_dtype": dtype,
134
  }
135
 
136
  if "whisper" in config.audio_model_id.lower():
 
143
  # GLM-ASR models use audio_tower as the encoder
144
  # Requires transformers >= 5.x or installed from source
145
  from transformers import AutoModelForSeq2SeqLM
146
+ import gc
147
 
148
  full_model = AutoModelForSeq2SeqLM.from_pretrained(
149
  config.audio_model_id, trust_remote_code=True, **encoder_kwargs
150
  )
151
  # GLM stores encoder at audio_tower (GlmAsrEncoder)
152
  encoder = full_model.audio_tower
153
+ # Clear references to free VRAM from the LLM decoder
154
+ full_model.language_model = None
155
+ full_model.multi_modal_projector = None
156
  del full_model
157
+ gc.collect()
158
+ if torch.cuda.is_available():
159
+ torch.cuda.empty_cache()
160
  else:
161
  encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
162
 
 
434
  messages.append({"role": "system", "content": system_prompt})
435
  messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
436
 
437
+ chat_result = self.tokenizer.apply_chat_template(
438
  messages,
439
  tokenize=True,
440
  add_generation_prompt=True,
441
  return_tensors="pt",
442
+ )
443
+ input_ids = chat_result.input_ids.to(device)
444
 
445
  if input_ids.dim() == 1:
446
  input_ids = input_ids.unsqueeze(0)