SSCD / text /model_loader.py
dashakoryakovskaya's picture
Upload 4 files
939957c verified
Raw
History Blame Contribute Delete
2.53 kB
# coding: utf-8
import torch
from .architectures import (
EmotionMamba,
PersonalityMamba,
FusionTransformer,
)
def load_pretrained_emotion_encoder(checkpoint_path, device):
    """Create an ``EmotionMamba`` encoder and restore its weights from disk.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the ``"model_state_dict"`` key; both layouts are accepted.

    Args:
        checkpoint_path: Location of the checkpoint, forwarded to ``torch.load``.
        device: Target device; the model is moved there and it is also used as
            ``map_location`` when reading the checkpoint.

    Returns:
        The ``EmotionMamba`` instance in eval mode, with an extra
        ``extract_features(inputs, lengths)`` callable attached as an attribute.
    """
    # Fixed hyper-parameters; they have to match the checkpoint being loaded,
    # otherwise ``load_state_dict`` will report shape/key mismatches.
    encoder_config = dict(
        input_dim_emotion=1024,
        input_dim_personality=1024,
        hidden_dim=256,
        out_features=128,
        mamba_layer_number=2,
        dropout=0.1,
    )
    encoder = EmotionMamba(**encoder_config)
    encoder = encoder.to(device)

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    loaded = torch.load(checkpoint_path, map_location=device)
    if "model_state_dict" in loaded:
        weights = loaded["model_state_dict"]
    else:
        weights = loaded
    encoder.load_state_dict(weights)

    def extract_features(inputs, lengths):
        # Project the inputs, then run them through every encoder block in
        # order. ``lengths`` is accepted but not used here.
        hidden = encoder.emo_proj(inputs)
        for layer in encoder.emotion_encoder:
            hidden = layer(hidden)
        return hidden

    # Attach the closure so callers can reach it as ``model.extract_features``.
    encoder.extract_features = extract_features
    encoder.eval()
    return encoder
def load_pretrained_personality_encoder(checkpoint_path, device):
    """Build a ``PersonalityMamba`` encoder and load its pretrained weights.

    The checkpoint may be a bare state dict or a dict holding the state dict
    under ``"model_state_dict"`` (the same two layouts the emotion encoder
    loader accepts).

    Args:
        checkpoint_path: Path of the checkpoint file, passed to ``torch.load``.
        device: Device the model is moved to; also used as ``map_location``
            when the checkpoint is read.

    Returns:
        The ``PersonalityMamba`` instance in eval mode, with an additional
        ``extract_features(inputs, lengths)`` callable attached as an attribute.
    """
    # Hyper-parameters must agree with the checkpoint, otherwise
    # ``load_state_dict`` fails on mismatching keys/shapes.
    personality_model = PersonalityMamba(
        input_dim_emotion=1024,
        input_dim_personality=1024,
        hidden_dim=64,
        out_features=256,
        mamba_layer_number=3,
        dropout=0.1,
    ).to(device)

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept training-style checkpoints that wrap the weights in a dict.
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    personality_model.load_state_dict(state_dict)

    def extract_features(inputs, lengths):
        # ``lengths`` is accepted but not used here.
        features = personality_model.per_proj(inputs)
        for block in personality_model.personality_encoder:
            # Each block receives the same tensor three times
            # (presumably query/key/value -- confirm against the architecture).
            features = block(features, features, features)
        return features

    # Expose the closure as ``model.extract_features`` for downstream callers.
    personality_model.extract_features = extract_features
    personality_model.eval()
    return personality_model
def load_fusion_model(
    fusion_checkpoint_path: str,
    emotion_encoder_checkpoint: str,
    personality_encoder_checkpoint: str,
    device: str = "cpu",
):
    """Assemble the ``FusionTransformer`` from its two encoders and load it.

    Both encoders are first created and restored from their own checkpoints,
    then wrapped in a ``FusionTransformer`` whose weights are loaded from
    ``fusion_checkpoint_path``. The fusion checkpoint may be a bare state dict
    or a dict holding the state dict under ``"model_state_dict"``.

    Args:
        fusion_checkpoint_path: Checkpoint file for the fusion model.
        emotion_encoder_checkpoint: Checkpoint file for the emotion encoder.
        personality_encoder_checkpoint: Checkpoint file for the personality
            encoder.
        device: Device specifier accepted by ``torch.device`` (default
            ``"cpu"``).

    Returns:
        A ``(fusion_model, device)`` tuple: the model in eval mode and the
        resolved ``torch.device`` it lives on.
    """
    device = torch.device(device)

    emotion_encoder = load_pretrained_emotion_encoder(emotion_encoder_checkpoint, device)
    personality_encoder = load_pretrained_personality_encoder(personality_encoder_checkpoint, device)

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(fusion_checkpoint_path, map_location=device)
    # Accept training-style checkpoints that wrap the weights in a dict,
    # consistent with the emotion encoder loader.
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint

    fusion_model = FusionTransformer(
        emo_model=emotion_encoder,
        per_model=personality_encoder,
        hidden_dim=128,
        out_features=64,
        tr_layer_number=3,
        num_transformer_heads=16,
        dropout=0.1,
    ).to(device)

    fusion_model.load_state_dict(state_dict)
    fusion_model.eval()
    return fusion_model, device