# models.py — stress-classification teacher/student models
# (Hugging Face upload by forwarder1121, commit 804e0b1)
import os
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
# ===================== Hard-coded configuration =====================
# Input feature widths for each modality consumed by TeacherNet.
DIM_ECG = 46            # ECG-branch input width — presumably per-window ECG features; TODO confirm
DIM_RR = 62             # RR-interval branch input width — verify against the feature extractor
DIM_EDA = 24            # electrodermal-activity branch input width
DIM_VIDEO = 84          # video branch input width
DIM_W2V = 512           # student input width — presumably a pooled wav2vec embedding; TODO confirm
TEACHER_FEAT_DIM = 128  # per-modality hidden width inside TeacherNet
NUM_CLASSES = 2         # classifier output size (binary classification)
STU_BN_DIM = 256        # student bottleneck width; also exported as StressConfig.hidden_size
class TeacherNet(nn.Module):
    """Multimodal teacher network.

    Each modality (ECG, RR, EDA, video) passes through its own single-layer
    MLP encoder; the four 128-d encodings are concatenated and fed to a small
    MLP classifier.  ``forward`` returns both the logits and the fused
    feature vector (the latter presumably serves as a distillation target
    for StudentNet — confirm against the training script).
    """

    def __init__(self):
        super().__init__()
        hidden = TEACHER_FEAT_DIM
        dropout = 0.4
        # NOTE: attribute names and Sequential layouts are load-bearing —
        # they define the state_dict keys of saved teacher checkpoints.
        self.ecg_encoder = self._encoder_block(DIM_ECG, hidden, dropout)
        self.rr_encoder = self._encoder_block(DIM_RR, hidden, dropout)
        self.eda_encoder = self._encoder_block(DIM_EDA, hidden, dropout)
        self.video_encoder = self._encoder_block(DIM_VIDEO, hidden, dropout)
        self.classifier = nn.Sequential(
            nn.Linear(4 * hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, NUM_CLASSES),
        )

    @staticmethod
    def _encoder_block(in_dim, out_dim, p):
        # One per-modality encoder: affine -> ReLU -> dropout.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
        )

    def forward(self, x_ecg, x_rr, x_eda, x_video):
        """Encode every modality, fuse by concatenation, classify.

        Returns:
            (logits, feat): logits of shape (batch, NUM_CLASSES) and the
            fused feature of shape (batch, 4 * TEACHER_FEAT_DIM).
        """
        encodings = [
            self.ecg_encoder(x_ecg),
            self.rr_encoder(x_rr),
            self.eda_encoder(x_eda),
            self.video_encoder(x_video),
        ]
        feat = torch.cat(encodings, dim=1)
        return self.classifier(feat), feat
class StudentNet(nn.Module):
    """Unimodal (audio) student network.

    Maps a DIM_W2V-dimensional input to the teacher's fused feature width
    (4 * TEACHER_FEAT_DIM) and classifies it.  ``forward`` returns
    (logits, feat); feat matches the teacher feature width, presumably for a
    distillation loss — confirm against the training script.
    """
    def __init__(self):
        super().__init__()
        dropout = 0.3
        # Target width = the teacher's concatenated 4-modality feature.
        teacher_feat_dim = 4 * TEACHER_FEAT_DIM
        self.encoder = nn.Sequential(
            nn.Linear(DIM_W2V, STU_BN_DIM),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(STU_BN_DIM, teacher_feat_dim),
            nn.ReLU(inplace=True)
        )
        self.norm = nn.LayerNorm(teacher_feat_dim)
        # NOTE(review): a Sequential nested inside a Sequential applies two
        # Dropout layers back-to-back (effective train-time drop rate ~0.51,
        # not 0.3) — this looks accidental.  However, the nesting determines
        # the state_dict key layout (classifier.1.1.weight, ...) that released
        # checkpoints use and that from_pretrained loads with strict=True, so
        # it must NOT be flattened without re-keying the checkpoint.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(teacher_feat_dim, STU_BN_DIM),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout),
                nn.Linear(STU_BN_DIM, NUM_CLASSES)
            )
        )
    def forward(self, x_w2v):
        # x_w2v: assumed (batch, DIM_W2V) — TODO confirm with the caller.
        feat = self.encoder(x_w2v)
        feat = self.norm(feat)
        logits = self.classifier(feat)
        return logits, feat
# ==== Transformers Compatibility ==== #
class StressConfig(PretrainedConfig):
    """Configuration for StudentForAudioClassification.

    Defaults mirror the hard-coded training setup (hidden_size=STU_BN_DIM,
    num_labels=NUM_CLASSES).  Bug fixed: the original assigned both
    attributes AFTER ``super().__init__(**kwargs)``, silently clobbering any
    values a caller (or a saved config.json) supplied; it also set
    ``num_labels`` too late for PretrainedConfig to derive matching
    id2label/label2id mappings.
    """
    model_type = "audio-classification"

    def __init__(self, **kwargs):
        # Default num_labels BEFORE super() so PretrainedConfig builds the
        # matching id2label/label2id; an explicit caller value wins.
        kwargs.setdefault("num_labels", NUM_CLASSES)
        super().__init__(**kwargs)
        # Honor a caller-supplied hidden_size; otherwise use the student's
        # bottleneck width.
        self.hidden_size = kwargs.get("hidden_size", STU_BN_DIM)
class StudentForAudioClassification(PreTrainedModel):
    """HF-compatible wrapper exposing StudentNet as an audio classifier.

    The checkpoint on disk/hub was saved from a bare StudentNet, so
    ``from_pretrained`` loads ``pytorch_model.bin`` manually and re-keys it
    under the ``student.`` submodule prefix.
    """
    config_class = StressConfig

    def __init__(self, config: StressConfig):
        super().__init__(config)
        # NOTE: StudentNet's dimensions come from module-level constants, not
        # from `config` — changing config.hidden_size does not resize the net.
        self.student = StudentNet()
        self.post_init()

    def forward(self, input_values, labels=None, **kwargs):
        """Classify pooled audio features.

        Args:
            input_values: student input tensor (see StudentNet.forward).
            labels: optional class indices; when given, a cross-entropy loss
                is returned so HF Trainer/pipelines work (the original
                silently ignored labels and never produced a loss).
        """
        logits, _feat = self.student(input_values)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return SequenceClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        trust_remote_code=False,
        **kwargs
    ):
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs
        )
        model = cls(config)
        # [Key point] If the path is a local directory, locate the weight
        # file directly instead of going through the hub.
        if os.path.isdir(pretrained_model_name_or_path):
            bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if not os.path.isfile(bin_path):
                # Fail with a clear message rather than an opaque torch.load error.
                raise FileNotFoundError(
                    f"pytorch_model.bin not found in {pretrained_model_name_or_path}"
                )
        else:
            from huggingface_hub import hf_hub_download
            bin_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )
        sd = torch.load(bin_path, map_location="cpu", weights_only=True)
        # The checkpoint was saved from a bare StudentNet; re-key under the
        # "student." submodule — but don't double-prefix a checkpoint that
        # was already saved from this wrapper.
        if not all(k.startswith("student.") for k in sd):
            sd = {f"student.{k}": v for k, v in sd.items()}
        model.load_state_dict(sd, strict=True)
        return model