import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import f1_score, accuracy_score
import argparse
import os
from tqdm import tqdm

from models.vision import VisionEmotionModel
from models.audio import AudioEmotionModel
from models.text import TextIntentModel
from models.fusion import MultiModalFusion


class MultiModalDataset(Dataset):
    """
    Dataset for multi-modal training with aligned vision, audio, and text data.
    """
    def __init__(self, data_dir, split='train'):
        self.data_dir = data_dir
        self.split = split
        self.samples = self.load_samples()

    def load_samples(self):
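        # Placeholder: load aligned vision/audio/text samples for the given split from data_dir.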
        return []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return {
            'vision': sample['vision'],
            'audio': sample['audio'],
            'text': sample['text'],
            'emotion': sample['emotion'],
            'intent': sample['intent'],
            'engagement': sample['engagement'],
            'confidence': sample['confidence']
        }


def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    emotion_preds, emotion_labels = [], []
    intent_preds, intent_labels = [], []

    for batch in tqdm(dataloader, desc="Training"):
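        # Move modality inputs to the training device.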
        vision = batch['vision'].to(device)
        audio = batch['audio'].to(device)
        text_input_ids = batch['text']['input_ids'].to(device)
        text_attention_mask = batch['text']['attention_mask'].to(device)
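        # Ground-truth targets for the four prediction heads.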
        emotion_labels_batch = batch['emotion'].to(device)
        intent_labels_batch = batch['intent'].to(device)
        engagement_labels = batch['engagement'].to(device)
        confidence_labels = batch['confidence'].to(device)

        optimizer.zero_grad()
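        # Forward pass through the fusion model.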
        outputs = model(vision, audio, text_input_ids, text_attention_mask)
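        # Per-task losses.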
        emotion_loss = criterion['emotion'](outputs['emotion'], emotion_labels_batch)
        intent_loss = criterion['intent'](outputs['intent'], intent_labels_batch)
        engagement_loss = criterion['engagement'](outputs['engagement'], engagement_labels)
        confidence_loss = criterion['confidence'](outputs['confidence'], confidence_labels)
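        # Equally weighted multi-task objective.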
        loss = (emotion_loss + intent_loss + engagement_loss + confidence_loss) / 4

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
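        # Accumulate predictions and labels for epoch-level metrics.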
        emotion_preds.extend(outputs['emotion'].argmax(dim=1).cpu().numpy())
        emotion_labels.extend(emotion_labels_batch.cpu().numpy())
        intent_preds.extend(outputs['intent'].argmax(dim=1).cpu().numpy())
        intent_labels.extend(intent_labels_batch.cpu().numpy())
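    # Epoch-level classification metrics.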
    emotion_acc = accuracy_score(emotion_labels, emotion_preds)
    emotion_f1 = f1_score(emotion_labels, emotion_preds, average='weighted')
    intent_acc = accuracy_score(intent_labels, intent_preds)
    intent_f1 = f1_score(intent_labels, intent_preds, average='weighted')

    return total_loss / len(dataloader), emotion_acc, emotion_f1, intent_acc, intent_f1


def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    emotion_preds, emotion_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            vision = batch['vision'].to(device)
            audio = batch['audio'].to(device)
            text_input_ids = batch['text']['input_ids'].to(device)
            text_attention_mask = batch['text']['attention_mask'].to(device)

            emotion_labels_batch = batch['emotion'].to(device)
            intent_labels_batch = batch['intent'].to(device)
            engagement_labels = batch['engagement'].to(device)
            confidence_labels = batch['confidence'].to(device)

            outputs = model(vision, audio, text_input_ids, text_attention_mask)

            emotion_loss = criterion['emotion'](outputs['emotion'], emotion_labels_batch)
            intent_loss = criterion['intent'](outputs['intent'], intent_labels_batch)
            engagement_loss = criterion['engagement'](outputs['engagement'], engagement_labels)
            confidence_loss = criterion['confidence'](outputs['confidence'], confidence_labels)

            loss = (emotion_loss + intent_loss + engagement_loss + confidence_loss) / 4
            total_loss += loss.item()

            emotion_preds.extend(outputs['emotion'].argmax(dim=1).cpu().numpy())
            emotion_labels.extend(emotion_labels_batch.cpu().numpy())

    emotion_acc = accuracy_score(emotion_labels, emotion_preds)
    emotion_f1 = f1_score(emotion_labels, emotion_preds, average='weighted')

    return total_loss / len(dataloader), emotion_acc, emotion_f1


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
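    # Single-modality models (instantiated here; not consumed directly by the fusion training loop below).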
    vision_model = VisionEmotionModel(num_emotions=args.num_emotions)
    audio_model = AudioEmotionModel(num_emotions=args.num_emotions)
    text_model = TextIntentModel(num_intents=args.num_intents)
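    # Fusion model with emotion, intent, engagement, and confidence heads.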
    fusion_model = MultiModalFusion(
        vision_dim=768,
        audio_dim=128,
        text_dim=768,
        num_emotions=args.num_emotions,
        num_intents=args.num_intents
    ).to(device)
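    # Classification losses for emotion/intent; regression losses for engagement/confidence.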
    criterion = {
        'emotion': nn.CrossEntropyLoss(),
        'intent': nn.CrossEntropyLoss(),
        'engagement': nn.MSELoss(),
        'confidence': nn.MSELoss()
    }

    optimizer = optim.Adam(fusion_model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
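    # Datasets and loaders.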
    train_dataset = MultiModalDataset(args.data_dir, 'train')
    val_dataset = MultiModalDataset(args.data_dir, 'val')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    best_f1 = 0
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        train_loss, train_acc, train_f1, intent_acc, intent_f1 = train_epoch(
            fusion_model, train_loader, optimizer, criterion, device
        )

        val_loss, val_acc, val_f1 = validate_epoch(fusion_model, val_loader, criterion, device)
| | print(".4f") |
| | print(".4f") |
| |
|
| | scheduler.step() |
| |
|
| | |
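        # Save the checkpoint with the best validation emotion F1.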
        if val_f1 > best_f1:
            best_f1 = val_f1
            os.makedirs(args.output_dir, exist_ok=True)
            torch.save(fusion_model.state_dict(), os.path.join(args.output_dir, 'best_model.pth'))

    print("Training completed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train EMOTIA Multi-Modal Model")
    parser.add_argument('--data_dir', type=str, required=True, help='Path to preprocessed data')
    parser.add_argument('--output_dir', type=str, default='./models/checkpoints', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--num_emotions', type=int, default=7, help='Number of emotion classes')
    parser.add_argument('--num_intents', type=int, default=5, help='Number of intent classes')

    args = parser.parse_args()
    main(args)