File size: 4,262 Bytes
804e0b1 5aae203 c7cd929 916a66e 5aae203 804e0b1 be465ca 5aae203 be465ca 5aae203 804e0b1 5aae203 804e0b1 5aae203 804e0b1 5aae203 804e0b1 5aae203 be465ca 5aae203 7db3d8d 5aae203 916a66e 804e0b1 916a66e be465ca 916a66e 5ae3b5a 71a81a5 5ae3b5a 804e0b1 be465ca 804e0b1 be465ca 5ae3b5a be465ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | import os
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
# ===================== Hard-coded configuration =====================
# Per-modality input feature widths for the teacher's encoders.
# NOTE(review): presumably these match the upstream feature extractors'
# output sizes — verify against the data pipeline.
DIM_ECG = 46
DIM_RR = 62
DIM_EDA = 24
DIM_VIDEO = 84
DIM_W2V = 512  # audio (wav2vec-style) embedding width fed to StudentNet
TEACHER_FEAT_DIM = 128  # per-modality hidden width inside TeacherNet
NUM_CLASSES = 2  # binary classification
STU_BN_DIM = 256  # StudentNet bottleneck width
class TeacherNet(nn.Module):
    """Multimodal teacher: one MLP encoder per modality plus a fused classifier.

    Each modality (ECG, RR, EDA, video) is projected to TEACHER_FEAT_DIM,
    the four projections are concatenated, and a small MLP head produces
    class logits. The concatenated feature is also returned so a student
    can be distilled against it.
    """

    def __init__(self):
        super().__init__()
        hidden = TEACHER_FEAT_DIM
        dropout = 0.4

        def branch(in_dim):
            # Identical Linear -> ReLU -> Dropout stack for every modality.
            # Defined as a local factory so the module tree (and state-dict
            # keys such as "ecg_encoder.0.weight") is unchanged.
            return nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout),
            )

        self.ecg_encoder = branch(DIM_ECG)
        self.rr_encoder = branch(DIM_RR)
        self.eda_encoder = branch(DIM_EDA)
        self.video_encoder = branch(DIM_VIDEO)
        self.classifier = nn.Sequential(
            nn.Linear(4 * hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, NUM_CLASSES),
        )

    def forward(self, x_ecg, x_rr, x_eda, x_video):
        """Encode each modality, concatenate, classify.

        Returns:
            (logits, feat): logits of shape (B, NUM_CLASSES) and the fused
            feature of shape (B, 4 * TEACHER_FEAT_DIM).
        """
        fused = torch.cat(
            [
                self.ecg_encoder(x_ecg),
                self.rr_encoder(x_rr),
                self.eda_encoder(x_eda),
                self.video_encoder(x_video),
            ],
            dim=1,
        )
        return self.classifier(fused), fused
class StudentNet(nn.Module):
    """Audio-only student distilled from TeacherNet.

    Maps a DIM_W2V audio embedding into the teacher's fused feature space
    (4 * TEACHER_FEAT_DIM) so feature-level distillation can compare the
    two, then classifies from the LayerNorm-ed feature.
    """

    def __init__(self):
        super().__init__()
        drop = 0.3
        feat_dim = 4 * TEACHER_FEAT_DIM  # must equal the teacher's fused width

        self.encoder = nn.Sequential(
            nn.Linear(DIM_W2V, STU_BN_DIM),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop),
            nn.Linear(STU_BN_DIM, feat_dim),
            nn.ReLU(inplace=True),
        )
        self.norm = nn.LayerNorm(feat_dim)
        # NOTE(review): the outer Dropout immediately followed by a nested
        # Sequential that begins with another Dropout applies dropout twice
        # back-to-back in training mode. That looks accidental, but the
        # nesting is reproduced exactly so state-dict keys (e.g.
        # "classifier.1.1.weight") keep matching existing checkpoints loaded
        # with strict=True elsewhere in this file.
        self.classifier = nn.Sequential(
            nn.Dropout(p=drop),
            nn.Sequential(
                nn.Dropout(p=drop),
                nn.Linear(feat_dim, STU_BN_DIM),
                nn.ReLU(inplace=True),
                nn.Dropout(p=drop),
                nn.Linear(STU_BN_DIM, NUM_CLASSES),
            ),
        )

    def forward(self, x_w2v):
        """Return (logits, normalized feature) for a batch of audio embeddings."""
        normalized = self.norm(self.encoder(x_w2v))
        return self.classifier(normalized), normalized
# ==== Transformers Compatibility ==== #
class StressConfig(PretrainedConfig):
    # NOTE(review): "audio-classification" is a pipeline/task name rather than
    # an architecture identifier; left unchanged because auto-class / hub
    # registrations may key on it — confirm before renaming.
    model_type = "audio-classification"

    def __init__(self, **kwargs):
        # hidden_size / num_labels are assigned AFTER super().__init__(**kwargs),
        # so these hard-coded values deliberately override anything passed in
        # via kwargs (num_labels also drives PretrainedConfig's id2label setup).
        super().__init__(**kwargs)
        self.hidden_size = STU_BN_DIM
        self.num_labels = NUM_CLASSES
class StudentForAudioClassification(PreTrainedModel):
    """Transformers-compatible wrapper exposing StudentNet as an audio classifier."""

    config_class = StressConfig

    def __init__(self, config: StressConfig):
        super().__init__(config)
        self.student = StudentNet()
        # Standard HF hook: runs weight init / final setup.
        self.post_init()

    def forward(self, input_values, **kwargs):
        # Extra kwargs (attention_mask, labels, ...) are accepted but ignored;
        # no loss is computed even when labels are provided — NOTE(review):
        # confirm this is intended if the model is used with HF Trainer.
        logits, feat = self.student(input_values)
        # The distillation feature is discarded here; only logits are exposed.
        return SequenceClassifierOutput(logits=logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        trust_remote_code=False,
        **kwargs
    ):
        """Load config plus a raw StudentNet checkpoint from a local dir or the Hub.

        Expects a file named exactly "pytorch_model.bin" whose keys are bare
        StudentNet keys WITHOUT the "student." prefix. NOTE(review): a
        checkpoint produced by save_pretrained() on this wrapper would already
        carry the prefix, and the re-prefixing below would double it and make
        the strict load fail — confirm how checkpoints are produced.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs
        )
        model = cls(config)
        # [Key point] If the path is a local directory, locate the file
        # directly; otherwise download it from the Hugging Face Hub.
        if os.path.isdir(pretrained_model_name_or_path):
            bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        else:
            from huggingface_hub import hf_hub_download
            bin_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )
        # weights_only=True blocks arbitrary-code execution via pickle payloads.
        sd = torch.load(bin_path, map_location="cpu", weights_only=True)
        # Re-key the bare StudentNet state dict under this wrapper's
        # "student." attribute so load_state_dict matches the module tree.
        prefixed_sd = {f"student.{k}": v for k, v in sd.items()}
        model.load_state_dict(prefixed_sd, strict=True)
        return model
|