Update models.py
Browse files
models.py
CHANGED
|
@@ -130,6 +130,36 @@ class StudentForAudioClassification(PreTrainedModel):
|
|
| 130 |
self.post_init()
|
| 131 |
|
| 132 |
def forward(self, input_values, **kwargs):
|
| 133 |
-
# input_values: already processed W2V embeddings or raw audio features
|
| 134 |
logits, feat = self.student(input_values)
|
| 135 |
return SequenceClassifierOutput(logits=logits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
self.post_init()
|
| 131 |
|
| 132 |
def forward(self, input_values, **kwargs):
|
|
|
|
| 133 |
logits, feat = self.student(input_values)
|
| 134 |
return SequenceClassifierOutput(logits=logits)
|
| 135 |
+
|
| 136 |
+
@classmethod
|
| 137 |
+
def from_pretrained(
|
| 138 |
+
cls,
|
| 139 |
+
pretrained_model_name_or_path,
|
| 140 |
+
*args,
|
| 141 |
+
trust_remote_code=False,
|
| 142 |
+
**kwargs
|
| 143 |
+
):
|
| 144 |
+
# 1) 기본 Config 로드 (trust_remote_code=True 권장)
|
| 145 |
+
config = AutoConfig.from_pretrained(
|
| 146 |
+
pretrained_model_name_or_path,
|
| 147 |
+
trust_remote_code=trust_remote_code,
|
| 148 |
+
**kwargs
|
| 149 |
+
)
|
| 150 |
+
# 2) 빈 래퍼 인스턴스 생성
|
| 151 |
+
model = cls(config)
|
| 152 |
+
|
| 153 |
+
# 3) 원본 체크포인트 state_dict 로드
|
| 154 |
+
# (safetensors라면 safetensors 라이브러리로 읽으세요)
|
| 155 |
+
sd = torch.load(
|
| 156 |
+
f"{pretrained_model_name_or_path}/pytorch_model.bin",
|
| 157 |
+
map_location="cpu",
|
| 158 |
+
weights_only=True
|
| 159 |
+
)
|
| 160 |
+
# 4) 모든 키에 `student.` prefix 추가
|
| 161 |
+
prefixed_sd = {f"student.{k}": v for k, v in sd.items()}
|
| 162 |
+
|
| 163 |
+
# 5) 래퍼 모델에 가중치 로드
|
| 164 |
+
model.load_state_dict(prefixed_sd, strict=True)
|
| 165 |
+
return model
|