forwarder1121 commited on
Commit
5ae3b5a
·
verified ·
1 Parent(s): 54d3ebf

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +31 -1
models.py CHANGED
@@ -130,6 +130,36 @@ class StudentForAudioClassification(PreTrainedModel):
130
  self.post_init()
131
 
132
  def forward(self, input_values, **kwargs):
133
- # input_values: already processed W2V embeddings or raw audio features
134
  logits, feat = self.student(input_values)
135
  return SequenceClassifierOutput(logits=logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  self.post_init()
131
 
132
  def forward(self, input_values, **kwargs):
 
133
  logits, feat = self.student(input_values)
134
  return SequenceClassifierOutput(logits=logits)
135
+
136
+ @classmethod
137
+ def from_pretrained(
138
+ cls,
139
+ pretrained_model_name_or_path,
140
+ *args,
141
+ trust_remote_code=False,
142
+ **kwargs
143
+ ):
144
+ # 1) 기본 Config 로드 (trust_remote_code=True 권장)
145
+ config = AutoConfig.from_pretrained(
146
+ pretrained_model_name_or_path,
147
+ trust_remote_code=trust_remote_code,
148
+ **kwargs
149
+ )
150
+ # 2) 빈 래퍼 인스턴스 생성
151
+ model = cls(config)
152
+
153
+ # 3) 원본 체크포인트 state_dict 로드
154
+ # (safetensors라면 safetensors 라이브러리로 읽으세요)
155
+ sd = torch.load(
156
+ f"{pretrained_model_name_or_path}/pytorch_model.bin",
157
+ map_location="cpu",
158
+ weights_only=True
159
+ )
160
+ # 4) 모든 키에 `student.` prefix 추가
161
+ prefixed_sd = {f"student.{k}": v for k, v in sd.items()}
162
+
163
+ # 5) 래퍼 모델에 가중치 로드
164
+ model.load_state_dict(prefixed_sd, strict=True)
165
+ return model