|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
|
|
|
|
# --- Input feature dimensionalities, one per modality -----------------------
# Fixed by the upstream feature extraction; presumably pre-pooled vectors
# (each encoder below is a plain Linear over these) -- confirm against the
# dataset/feature-extraction code.
DIM_ECG = 46     # electrocardiogram features (teacher input)
DIM_RR = 62      # RR / respiration features -- TODO confirm exact meaning
DIM_EDA = 24     # electrodermal-activity features (teacher input)
DIM_VIDEO = 84   # video-derived features (teacher input)
DIM_W2V = 512    # speech embedding width (name suggests wav2vec) -- confirm

# --- Architecture hyper-parameters -------------------------------------------
TEACHER_FEAT_DIM = 128   # per-modality embedding width inside TeacherNet
NUM_CLASSES = 2          # binary classification head size
STU_BN_DIM = 256         # student hidden ("bottleneck") width
|
|
|
|
|
class TeacherNet(nn.Module):
    """Multimodal teacher: one MLP encoder per modality (ECG, RR, EDA,
    video), whose embeddings are concatenated and classified.

    ``forward`` returns ``(logits, fused_feature)``; the fused feature
    (batch, 4 * TEACHER_FEAT_DIM) is what a student can distill against.
    """

    def __init__(self):
        super().__init__()
        width = TEACHER_FEAT_DIM
        p_drop = 0.4

        def make_encoder(in_dim):
            # Project one modality into the shared width: Linear -> ReLU -> Dropout.
            return nn.Sequential(
                nn.Linear(in_dim, width),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p_drop),
            )

        # Creation order matters for seeded init and state_dict keys; keep it.
        self.ecg_encoder = make_encoder(DIM_ECG)
        self.rr_encoder = make_encoder(DIM_RR)
        self.eda_encoder = make_encoder(DIM_EDA)
        self.video_encoder = make_encoder(DIM_VIDEO)
        self.classifier = nn.Sequential(
            nn.Linear(4 * width, width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p_drop),
            nn.Linear(width, NUM_CLASSES),
        )

    def forward(self, x_ecg, x_rr, x_eda, x_video):
        """Encode each modality, fuse by concatenation, classify.

        Returns:
            (logits, fused): logits (batch, NUM_CLASSES) and the concatenated
            feature (batch, 4 * TEACHER_FEAT_DIM).
        """
        embeddings = [
            self.ecg_encoder(x_ecg),
            self.rr_encoder(x_rr),
            self.eda_encoder(x_eda),
            self.video_encoder(x_video),
        ]
        fused = torch.cat(embeddings, dim=1)
        return self.classifier(fused), fused
|
|
|
|
|
class StudentNet(nn.Module):
    """Speech-only student: maps a DIM_W2V vector to a feature matching the
    teacher's fused width (4 * TEACHER_FEAT_DIM), then classifies it.

    ``forward`` returns ``(logits, normalized_feature)`` so the feature can
    be aligned with the teacher's during distillation.
    """

    def __init__(self):
        super().__init__()
        p_drop = 0.3
        feat_dim = 4 * TEACHER_FEAT_DIM  # width of the teacher's fused feature

        self.encoder = nn.Sequential(
            nn.Linear(DIM_W2V, STU_BN_DIM),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p_drop),
            nn.Linear(STU_BN_DIM, feat_dim),
            nn.ReLU(inplace=True),
        )
        self.norm = nn.LayerNorm(feat_dim)
        # NOTE(review): the outer Dropout plus the inner Sequential's leading
        # Dropout apply dropout twice back-to-back (effective keep prob ~0.49).
        # Looks accidental, but the nesting is preserved deliberately: it
        # determines the checkpoint's state_dict keys (classifier.1.*).
        head = nn.Sequential(
            nn.Dropout(p=p_drop),
            nn.Linear(feat_dim, STU_BN_DIM),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p_drop),
            nn.Linear(STU_BN_DIM, NUM_CLASSES),
        )
        self.classifier = nn.Sequential(nn.Dropout(p=p_drop), head)

    def forward(self, x_w2v):
        """Encode, LayerNorm, classify; returns (logits, feature)."""
        feature = self.norm(self.encoder(x_w2v))
        return self.classifier(feature), feature
|
|
|
|
|
|
|
|
|
|
|
class StressConfig(PretrainedConfig):
    """Configuration for the distilled stress-classification student.

    NOTE(review): ``model_type`` is set to the *task* name
    "audio-classification" rather than an architecture name; kept as-is so
    existing saved ``config.json`` files keep resolving.
    """

    model_type = "audio-classification"

    def __init__(self, hidden_size=STU_BN_DIM, num_labels=NUM_CLASSES, **kwargs):
        """Args:
            hidden_size: student hidden width (default STU_BN_DIM).
            num_labels: number of output classes (default NUM_CLASSES).

        Previously both values were hard-coded after ``super().__init__``,
        silently clobbering anything the caller (or a loaded config.json)
        passed in. Keyword defaults preserve the old values while honouring
        explicit overrides.
        """
        # num_labels goes through PretrainedConfig so id2label/label2id stay in sync.
        super().__init__(num_labels=num_labels, **kwargs)
        self.hidden_size = hidden_size
|
|
|
|
|
class StudentForAudioClassification(PreTrainedModel):
    """transformers-compatible wrapper around StudentNet so the model works
    with the audio-classification pipeline and ``transformers.Trainer``."""

    config_class = StressConfig

    def __init__(self, config: StressConfig):
        super().__init__(config)
        self.student = StudentNet()
        self.post_init()

    def forward(self, input_values, labels=None, **kwargs):
        """Classify pooled speech embeddings.

        Args:
            input_values: float tensor of speech features; assumed
                (batch, DIM_W2V) -- TODO confirm against the feature extractor.
            labels: optional (batch,) class indices. When given, a
                cross-entropy loss is returned. (Previously labels arriving
                via **kwargs were silently dropped, so Trainer got no loss.)

        Returns:
            SequenceClassifierOutput with ``logits`` and, if labels were
            provided, ``loss``.
        """
        logits, _feat = self.student(input_values)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, NUM_CLASSES), labels.view(-1)
            )
        return SequenceClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        trust_remote_code=False,
        **kwargs
    ):
        """Build the model and load weights from ``pytorch_model.bin``.

        Accepts either a raw StudentNet state dict (keys like
        ``encoder.0.weight``) or one saved from this wrapper (keys already
        prefixed with ``student.``). Only the ``.bin`` format is handled;
        safetensors checkpoints are not supported here.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs
        )
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        else:
            from huggingface_hub import hf_hub_download
            bin_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="pytorch_model.bin",
            )
        # weights_only=True keeps torch.load from unpickling arbitrary objects.
        sd = torch.load(bin_path, map_location="cpu", weights_only=True)
        # Remap raw StudentNet keys onto the wrapper's "student." namespace;
        # unconditional prefixing (the old behavior) broke on checkpoints that
        # were saved from this wrapper and thus already carried the prefix.
        sd = {
            k if k.startswith("student.") else f"student.{k}": v
            for k, v in sd.items()
        }
        model.load_state_dict(sd, strict=True)
        return model
|
|
|