# coding: utf-8
"""Helpers that build the emotion, personality and fusion models from checkpoints."""

import torch

from .architectures import (
    EmotionMamba,
    PersonalityMamba,
    FusionTransformer,
)


def _unwrap_state_dict(checkpoint):
    """Return the state dict stored in *checkpoint*.

    Checkpoints are accepted in two layouts: the state dict itself, or a
    dict that stores the state dict under the ``"model_state_dict"`` key.
    """
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def load_pretrained_emotion_encoder(checkpoint_path, device):
    """Build an ``EmotionMamba``, load its weights and switch it to eval mode.

    Args:
        checkpoint_path: Path passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The loaded ``EmotionMamba`` instance with an extra
        ``extract_features(inputs, lengths)`` callable attached to it.
    """
    emotion_model = EmotionMamba(
        input_dim_emotion=1024,
        input_dim_personality=1024,
        hidden_dim=256,
        out_features=128,
        mamba_layer_number=2,
        dropout=0.1,
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider ``weights_only=True`` where supported).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    emotion_model.load_state_dict(_unwrap_state_dict(checkpoint))

    def extract_features(inputs, lengths):
        """Project *inputs* and run them through every emotion encoder block.

        ``lengths`` is accepted but not used by this function.
        """
        features = emotion_model.emo_proj(inputs)
        for block in emotion_model.emotion_encoder:
            features = block(features)
        return features

    # Attached as a plain attribute so callers can reach the encoder-only path.
    emotion_model.extract_features = extract_features
    emotion_model.eval()
    return emotion_model


def load_pretrained_personality_encoder(checkpoint_path, device):
    """Build a ``PersonalityMamba``, load its weights and switch it to eval mode.

    Args:
        checkpoint_path: Path passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The loaded ``PersonalityMamba`` instance with an extra
        ``extract_features(inputs, lengths)`` callable attached to it.
    """
    personality_model = PersonalityMamba(
        input_dim_emotion=1024,
        input_dim_personality=1024,
        hidden_dim=64,
        out_features=256,
        mamba_layer_number=3,
        dropout=0.1,
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider ``weights_only=True`` where supported).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept the wrapped layout too, consistent with the emotion encoder loader.
    personality_model.load_state_dict(_unwrap_state_dict(checkpoint))

    def extract_features(inputs, lengths):
        """Project *inputs* and run them through every personality encoder block.

        Each block is called with the same tensor three times
        (presumably query/key/value -- confirm against the block definition).
        ``lengths`` is accepted but not used by this function.
        """
        features = personality_model.per_proj(inputs)
        for block in personality_model.personality_encoder:
            features = block(features, features, features)
        return features

    # Attached as a plain attribute so callers can reach the encoder-only path.
    personality_model.extract_features = extract_features
    personality_model.eval()
    return personality_model


def load_fusion_model(
    fusion_checkpoint_path: str,
    emotion_encoder_checkpoint: str,
    personality_encoder_checkpoint: str,
    device: str = "cpu",
):
    """Build the ``FusionTransformer`` on top of both pretrained encoders.

    Args:
        fusion_checkpoint_path: Checkpoint holding the fusion model weights.
        emotion_encoder_checkpoint: Checkpoint for the emotion encoder.
        personality_encoder_checkpoint: Checkpoint for the personality encoder.
        device: Device specifier handed to ``torch.device``.

    Returns:
        A ``(fusion_model, device)`` tuple: the model in eval mode and the
        resolved ``torch.device`` it lives on.
    """
    device = torch.device(device)

    emotion_encoder = load_pretrained_emotion_encoder(
        emotion_encoder_checkpoint, device
    )
    personality_encoder = load_pretrained_personality_encoder(
        personality_encoder_checkpoint, device
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider ``weights_only=True`` where supported).
    checkpoint = torch.load(fusion_checkpoint_path, map_location=device)

    fusion_model = FusionTransformer(
        emo_model=emotion_encoder,
        per_model=personality_encoder,
        hidden_dim=128,
        out_features=64,
        tr_layer_number=3,
        num_transformer_heads=16,
        dropout=0.1,
    ).to(device)

    # Accept the wrapped layout too, consistent with the encoder loaders.
    fusion_model.load_state_dict(_unwrap_state_dict(checkpoint))
    fusion_model.eval()
    return fusion_model, device