mazesmazes committed on
Commit
e4d0b19
·
verified ·
1 Parent(s): 4a81e2e

Training in progress - step 500

Browse files
Files changed (2) hide show
  1. asr_modeling.py +18 -0
  2. config.json +6 -0
asr_modeling.py CHANGED
@@ -38,6 +38,8 @@ class ASRModel(PreTrainedModel, GenerationMixin):
38
  @classmethod
39
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
40
  """Load model from pretrained, handling device placement correctly."""
 
 
41
  from safetensors.torch import load_file
42
  from transformers.utils.hub import cached_file
43
 
@@ -72,6 +74,22 @@ class ASRModel(PreTrainedModel, GenerationMixin):
72
  state_dict = load_file(model_file)
73
  model.load_state_dict(state_dict, strict=False)
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return model
76
  finally:
77
  cls._is_loading_from_pretrained = False
 
38
  @classmethod
39
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
40
  """Load model from pretrained, handling device placement correctly."""
41
+ from pathlib import Path
42
+
43
  from safetensors.torch import load_file
44
  from transformers.utils.hub import cached_file
45
 
 
74
  state_dict = load_file(model_file)
75
  model.load_state_dict(state_dict, strict=False)
76
 
77
+ # Load LoRA adapter if present
78
+ adapter_config = cached_file(
79
+ pretrained_model_name_or_path,
80
+ "adapter_config.json",
81
+ _raise_exceptions_for_missing_entries=False,
82
+ **cache_kwargs,
83
+ )
84
+ if adapter_config is not None:
85
+ from peft import PeftModel
86
+
87
+ # Get adapter directory (parent of adapter_config.json)
88
+ adapter_path = Path(adapter_config).parent
89
+ model.language_model = PeftModel.from_pretrained(
90
+ model.language_model, adapter_path, is_trainable=False
91
+ )
92
+
93
  return model
94
  finally:
95
  cls._is_loading_from_pretrained = False
config.json CHANGED
@@ -161,6 +161,10 @@
161
  "label_smoothing": 0.0,
162
  "length_penalty": 1.0,
163
  "llm_dim": 2048,
 
 
 
 
164
  "max_new_tokens": 96,
165
  "model_dtype": "bfloat16",
166
  "model_type": "asr_model",
@@ -169,6 +173,7 @@
169
  "num_experts": 4,
170
  "num_experts_per_tok": 2,
171
  "pipeline_tag": "automatic-speech-recognition",
 
172
  "projector_dropout": 0.0,
173
  "projector_hidden_dim": null,
174
  "projector_init_std": 0.02,
@@ -249,6 +254,7 @@
249
  "text_model_id": "Qwen/Qwen3-1.7B",
250
  "transformers_version": "5.0.0.dev0",
251
  "use_cache": false,
 
252
  "use_specaugment": true,
253
  "user_prompt": "Please transcribe this English audio into text: <audio>",
254
  "vocab_size": 151670
 
161
  "label_smoothing": 0.0,
162
  "length_penalty": 1.0,
163
  "llm_dim": 2048,
164
+ "lora_alpha": 32,
165
+ "lora_dropout": 0.0,
166
+ "lora_r": 32,
167
+ "lora_target_modules": "all-linear",
168
  "max_new_tokens": 96,
169
  "model_dtype": "bfloat16",
170
  "model_type": "asr_model",
 
173
  "num_experts": 4,
174
  "num_experts_per_tok": 2,
175
  "pipeline_tag": "automatic-speech-recognition",
176
+ "pretrained_model_path": "mazesmazes/tiny-audio-glm",
177
  "projector_dropout": 0.0,
178
  "projector_hidden_dim": null,
179
  "projector_init_std": 0.02,
 
254
  "text_model_id": "Qwen/Qwen3-1.7B",
255
  "transformers_version": "5.0.0.dev0",
256
  "use_cache": false,
257
+ "use_lora": true,
258
  "use_specaugment": true,
259
  "user_prompt": "Please transcribe this English audio into text: <audio>",
260
  "vocab_size": 151670