# LSDGNN_ICL / trainer.py
import argparse
import json
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
    precision_recall_fscore_support
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm

from dataloader import IEMOCAPDataset
from utils import person_embed
def train_or_eval_model(model, loss_function, dataloader, epoch, cuda, args, optimizer=None, train=False):
    """Run one epoch of training or evaluation and compute ERC metrics.

    Args:
        model: called as ``model(features, adj_1, adj_2, s_mask, s_mask_onehot,
            lengths)``; expected to return ``(log_prob, diff_loss)`` where
            ``log_prob`` is (B, N, C).
        loss_function: criterion applied to (B, C, N) log-probs vs (B, N) labels;
            padded positions carry label -1 and are filtered before metrics.
        dataloader: yields ``(features, label, adj_1, adj_2, s_mask,
            s_mask_onehot, lengths, speakers, utterances)`` batches.
        epoch: current epoch index, used only for tensorboard histogram tagging.
        cuda: when True, move batch tensors to the GPU.
        args: needs ``.max_grad_norm``, ``.tensorboard``, ``.dataset_name``.
        optimizer: required when ``train`` is True.
        train: run backprop and optimizer steps when True.

    Returns:
        For IEMOCAP/MELD/EmoryNLP:
            (avg_loss, avg_accuracy, labels, preds, weighted_f1, f1_per_class, macro_f1)
        Otherwise:
            (avg_loss, avg_accuracy, labels, preds, micro_f1, macro_f1)
        If the dataloader yields no batches, a 10-tuple of NaN/empty
        placeholders is returned. NOTE(review): that arity differs from both
        normal returns — callers must not unpack the empty case blindly;
        preserved as-is for backward compatibility.
    """
    losses, preds, labels = [], [], []
    assert not train or optimizer is not None

    if train:
        model.train()
    else:
        model.eval()

    for data in dataloader:
        if train:
            optimizer.zero_grad()

        features, label, adj_1, adj_2, s_mask, s_mask_onehot, lengths, speakers, utterances = data
        if cuda:
            features = features.cuda()
            label = label.cuda()
            adj_1 = adj_1.cuda()
            adj_2 = adj_2.cuda()
            s_mask = s_mask.cuda()
            s_mask_onehot = s_mask_onehot.cuda()
            lengths = lengths.cuda()

        # Disable autograd during evaluation to save memory/time; outputs are
        # unchanged since gradients are only consumed in the train branch.
        with torch.set_grad_enabled(train):
            log_prob, diff_loss = model(features, adj_1, adj_2, s_mask, s_mask_onehot, lengths)  # (B, N, C)
            loss = loss_function(log_prob.permute(0, 2, 1), label) + diff_loss

        labels += label.cpu().numpy().tolist()
        preds += torch.argmax(log_prob, dim=2).cpu().numpy().tolist()
        losses.append(loss.item())

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if args.tensorboard:
                # NOTE(review): `writer` is not defined in this module chunk —
                # presumably a module-level SummaryWriter; confirm before use.
                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.grad, epoch)
            optimizer.step()

    if not preds:
        # Empty epoch: placeholder tuple (original arity intentionally kept).
        return float('nan'), float('nan'), [], [], float('nan'), [], [], [], [], []

    # Flatten batch-of-dialogs, dropping padded positions labelled -1.
    new_preds, new_labels = [], []
    for i, dialog_labels in enumerate(labels):
        for j, l in enumerate(dialog_labels):
            if l != -1:
                new_labels.append(l)
                new_preds.append(preds[i][j])

    avg_loss = round(np.mean(losses), 4)
    avg_accuracy = round(accuracy_score(new_labels, new_preds) * 100, 2)

    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        f1_per_class = f1_score(new_labels, new_preds, average=None)  # one F1 per class
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        return avg_loss, avg_accuracy, labels, preds, avg_fscore, f1_per_class, avg_macro_fscore
    else:
        # Micro-F1 over labels 1..6 only — class 0 is excluded (presumably
        # the neutral label of a DailyDialog-style dataset; confirm).
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        return avg_loss, avg_accuracy, labels, preds, avg_micro_fscore, avg_macro_fscore
def save_badcase(model, dataloader, cuda, args, speaker_vocab, label_vocab):
    """Run inference, dump per-utterance predictions to badcase/<dataset>.json,
    and print test F1 scores.

    NOTE(review): here the model is called with a single adjacency (5 inputs)
    and is expected to return only ``log_prob``, unlike ``train_or_eval_model``
    which passes two adjacencies and unpacks ``(log_prob, diff_loss)`` —
    confirm which model variant this path serves.

    Args:
        model: called as ``model(features, adj, s_mask, s_mask_onehot, lengths)``,
            returning (B, N, C) log-probs.
        dataloader: yields ``(features, label, adj, s_mask, s_mask_onehot,
            lengths, speaker, utterances)`` batches.
        cuda: when True, move batch tensors to the GPU.
        args: needs ``.dataset_name`` (used for the output path and metric choice).
        speaker_vocab / label_vocab: dicts with an ``'itos'`` index-to-string list.
    """
    preds, labels = [], []
    dialogs, speakers = [], []
    model.eval()

    with torch.no_grad():  # pure inference — no gradients needed
        for data in dataloader:
            features, label, adj, s_mask, s_mask_onehot, lengths, speaker, utterances = data
            if cuda:
                features = features.cuda()
                label = label.cuda()
                adj = adj.cuda()
                s_mask = s_mask.cuda()
                s_mask_onehot = s_mask_onehot.cuda()
                lengths = lengths.cuda()

            log_prob = model(features, adj, s_mask, s_mask_onehot, lengths)  # (B, N, C)
            labels += label.cpu().numpy().tolist()                            # (B, N)
            preds += torch.argmax(log_prob, dim=2).cpu().numpy().tolist()     # (B, N)
            dialogs += utterances
            speakers += speaker

    if not preds:
        return

    # Flatten, dropping padded positions labelled -1, for metric computation.
    new_preds, new_labels = [], []
    for i, dialog_labels in enumerate(labels):
        for j, l in enumerate(dialog_labels):
            if l != -1:
                new_labels.append(l)
                new_preds.append(preds[i][j])

    # One record per utterance: text, speaker, gold label ('none' for padding),
    # and the model's prediction.
    cases = []
    for i, dialog in enumerate(dialogs):
        case = [
            {
                'text': utterance,
                'speaker': speaker_vocab['itos'][speakers[i][j]],
                'label': label_vocab['itos'][labels[i][j]] if labels[i][j] != -1 else 'none',
                'pred': label_vocab['itos'][preds[i][j]],
            }
            for j, utterance in enumerate(dialog)
        ]
        cases.append(case)

    os.makedirs('badcase', exist_ok=True)  # avoid FileNotFoundError on first run
    with open('badcase/%s.json' % (args.dataset_name), 'w', encoding='utf-8') as f:
        json.dump(cases, f)

    if args.dataset_name in ['IEMOCAP', 'MELD', 'EmoryNLP']:
        avg_fscore = round(f1_score(new_labels, new_preds, average='weighted') * 100, 2)
        print('badcase saved')
        print('test_f1', avg_fscore)
    else:
        # Micro-F1 over labels 1..6 only — class 0 is excluded (presumably
        # the neutral label of a DailyDialog-style dataset; confirm).
        avg_micro_fscore = round(f1_score(new_labels, new_preds, average='micro', labels=list(range(1, 7))) * 100, 2)
        avg_macro_fscore = round(f1_score(new_labels, new_preds, average='macro') * 100, 2)
        print('badcase saved')
        print('test_micro_f1', avg_micro_fscore)
        print('test_macro_f1', avg_macro_fscore)