Update custom model files, README, and requirements
asr_modeling.py  (+11 -3)  CHANGED
@@ -130,7 +130,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         encoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "low_cpu_mem_usage": True,
-            "
+            "torch_dtype": dtype,
         }

         if "whisper" in config.audio_model_id.lower():
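The hunk above forwards the target dtype through encoder_kwargs, so from_pretrained materializes the encoder weights directly in that dtype instead of loading float32 copies and casting afterwards. A minimal, self-contained sketch of the same pattern; the encoder id and dtype below are illustrative placeholders, not this repo's config values.

import torch
from transformers import AutoModel

# Sketch only: "torch_dtype" in the kwargs dict is forwarded to from_pretrained,
# so the checkpoint is loaded straight into the requested dtype.
encoder_kwargs = {
    "attn_implementation": "sdpa",
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.bfloat16,
}
encoder = AutoModel.from_pretrained("openai/whisper-large-v3", **encoder_kwargs)
print(next(encoder.parameters()).dtype)  # torch.bfloat16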
@@ -143,13 +143,20 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             # GLM-ASR models use audio_tower as the encoder
             # Requires transformers >= 5.x or installed from source
             from transformers import AutoModelForSeq2SeqLM
+            import gc

             full_model = AutoModelForSeq2SeqLM.from_pretrained(
                 config.audio_model_id, trust_remote_code=True, **encoder_kwargs
             )
             # GLM stores encoder at audio_tower (GlmAsrEncoder)
             encoder = full_model.audio_tower
+            # Clear references to free VRAM from the LLM decoder
+            full_model.language_model = None
+            full_model.multi_modal_projector = None
             del full_model
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

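The hunk above keeps only the GLM audio_tower and then drops the references to the unused decoder-side submodules before del, gc.collect(), and torch.cuda.empty_cache(). The sketch below reproduces that cleanup pattern with a toy parent module standing in for the full seq2seq model (the class and layer sizes are illustrative, not the GLM implementation) and prints allocated CUDA memory so the effect can be checked.

import gc
import torch

def report(tag):
    if torch.cuda.is_available():
        print(f"{tag}: {torch.cuda.memory_allocated() / 2**20:.1f} MiB allocated")

class ToyParent(torch.nn.Module):
    # Stand-in for the full GLM model: one submodule we keep, one we discard.
    def __init__(self):
        super().__init__()
        self.audio_tower = torch.nn.Linear(4096, 4096)
        self.language_model = torch.nn.Linear(4096, 4096)

device = "cuda" if torch.cuda.is_available() else "cpu"
full_model = ToyParent().to(device)
report("after load")

encoder = full_model.audio_tower    # keep only the encoder submodule
full_model.language_model = None    # drop the reference to the unused part
del full_model                      # parent module is now unreachable
gc.collect()                        # collect it (and any reference cycles)
if torch.cuda.is_available():
    torch.cuda.empty_cache()        # hand cached blocks back to the driver
report("after cleanup")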
@@ -427,12 +434,13 @@
         messages.append({"role": "system", "content": system_prompt})
         messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})

-
+        chat_result = self.tokenizer.apply_chat_template(
             messages,
             tokenize=True,
             add_generation_prompt=True,
             return_tensors="pt",
-        )
+        )
+        input_ids = chat_result.input_ids.to(device)

         if input_ids.dim() == 1:
             input_ids = input_ids.unsqueeze(0)
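The hunk above stores the chat-template output in chat_result and reads input_ids from it before moving the tensor to the device. A hedged, standalone sketch of that call follows; the tokenizer id and messages are illustrative, and return_dict=True is passed explicitly so the returned BatchEncoding exposes .input_ids (with the transformers 4.x default return_dict=False, apply_chat_template hands back the token ids directly).

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative
device = "cuda" if torch.cuda.is_available() else "cpu"

messages = [
    {"role": "system", "content": "You are a speech transcription assistant."},
    {"role": "user", "content": "Transcribe the audio: <audio>"},
]
chat_result = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,           # makes .input_ids available on the result
)
input_ids = chat_result.input_ids.to(device)
if input_ids.dim() == 1:        # defensive batch dimension, mirroring the hunk above
    input_ids = input_ids.unsqueeze(0)
print(input_ids.shape)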
|