Spaces:
Sleeping
Sleeping
| import time | |
| import torch | |
| from torch import nn as nn, optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import config | |
| from dataset import SpeechEmotionDataset, extract_mfcc | |
| from model import SpeechEmotionModel | |
| from utils import load_checkpoint, save_checkpoint | |
def train_fn(model, loader, opt, criterion, epoch):
    """Run one training epoch over ``loader`` and report running loss/accuracy.

    Args:
        model: network mapping a feature batch to class logits.
        loader: DataLoader yielding ``(feature, label)`` batches; labels are
            one-hot (the accuracy check uses ``label.argmax(1)``).
        opt: optimizer stepping ``model``'s parameters.
        criterion: loss function applied to ``(logits, label)``.
        epoch: current epoch index, shown in the progress bar.
    """
    loop = tqdm(loader, leave=True)
    model.train()
    epoch_loss = 0.0
    # Bug fix: these counters were previously reset inside the loop, so the
    # displayed "accuracy" was per-batch only. Accumulate over the whole epoch.
    total_acc, total_count = 0, 0
    for idx, (feature, label) in enumerate(loop):
        feature = feature.to(config.DEVICE)
        label = label.to(config.DEVICE)
        opt.zero_grad()
        # Insert a singleton dim so the feature matches the model's expected rank
        # (assumes model wants (batch, n_mfcc, 1, ...) — TODO confirm with model.py).
        feature = torch.unsqueeze(feature, dim=2)
        predicted_label = model(feature)
        loss = criterion(predicted_label, label)
        epoch_loss += loss.item()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        opt.step()
        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)
        # Bug fix: average the running loss over batches seen so far (idx + 1),
        # not len(loader) — the old form understated the loss mid-epoch.
        loop.set_postfix(
            {
                "epoch": epoch,
                "loss": epoch_loss / (idx + 1),
                "accuracy": total_acc / total_count,
            }
        )
def main():
    """Train the speech-emotion model for ``config.NUM_EPOCHS`` epochs.

    Builds the model, optimizer and loaders from the ``config`` module,
    validates after every epoch, steps an LR scheduler on the validation
    loss, and optionally saves a checkpoint each epoch.
    """
    model = SpeechEmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): checkpoint loading was commented out here; re-enable via
    # config.LOAD_MODEL once checkpoints are compatible:
    #   if config.LOAD_MODEL:
    #       load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)

    train_dataset = SpeechEmotionDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    val_dataset = SpeechEmotionDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )

    # Halve the LR when validation loss plateaus for `patience` epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2, verbose=True)

    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        train_fn(model, train_loader, opt, criterion, epoch)
        accu_val, loss_val = evaluate(model, criterion, val_loader)
        scheduler.step(loss_val)

        separator = "+" + "-" * 19 + "+" + "-" * 15 + "+" + "-" * 20 + "+" + "-" * 24 + "+"
        print(separator)
        print(
            "| end of epoch: {:3d} | time: {:6.2f}s | val_loss: {:8.3f} | "
            "val_accuracy: {:8.3f} |".format(
                epoch, time.time() - epoch_start_time, loss_val, accu_val
            )
        )
        print(separator)

        if config.SAVE_MODEL:
            save_checkpoint(model, opt, filename=config.CHECKPOINT)
def test():
    """Smoke-test inference: load a checkpoint, classify one WAV file, and
    report validation accuracy/loss.

    Uses a hard-coded sample path ("uploads/OAF_bar_fear.wav"); intended as a
    manual sanity check, not an automated test.
    """
    model = SpeechEmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()
    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)

    val_dataset = SpeechEmotionDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )

    # Index -> emotion name mapping (cross-check against val_dataset.class_to_idx,
    # which is printed below).
    idx_to_label = {0: 'anger', 1: 'disgust', 2: 'fear', 3: 'happy',
                    4: 'neutral', 5: 'ps', 6: 'sad'}

    mfcc = extract_mfcc("uploads/OAF_bar_fear.wav")
    mfcc = torch.from_numpy(mfcc)
    mfcc = mfcc.to(config.DEVICE)
    # Reshape the 1-D MFCC vector into a (1, n_mfcc, 1) batch — presumably the
    # rank the model expects; TODO confirm against model.py.
    mfcc = torch.unsqueeze(mfcc, dim=1)
    mfcc = torch.unsqueeze(mfcc, dim=0)

    model.eval()
    # Bug fix: run inference without building an autograd graph.
    with torch.no_grad():
        y_pred = model(mfcc)
    pred_idx = torch.argmax(y_pred)
    print(pred_idx)
    # Show the human-readable emotion as well (the mapping was previously
    # defined but never used).
    print(idx_to_label.get(int(pred_idx), "unknown"))
    print(val_dataset.class_to_idx)
    print(evaluate(model, criterion, val_loader))
def evaluate(model, criterion, dataloader):
    """Compute mean accuracy and mean loss of ``model`` over ``dataloader``.

    Labels are expected to be one-hot encoded (accuracy compares predicted
    class indices against ``labels.argmax(1)``).

    Returns:
        tuple: ``(accuracy, average_loss)`` where accuracy is over all samples
        and average_loss is the mean of per-batch losses.
    """
    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(config.DEVICE)
            batch_labels = batch_labels.to(config.DEVICE)
            batch_inputs = torch.unsqueeze(batch_inputs, dim=2)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()
            predictions = logits.argmax(dim=1)
            correct += (predictions == batch_labels.argmax(1)).sum().item()
            seen += batch_labels.size(0)
    return correct / seen, loss_sum / len(dataloader)
| if __name__ == "__main__": | |
| test() | |