# speech_eMEOWtion / train.py
# Author: Tanishq
# (upload metadata: "Upload 9 files", commit 4dee9c4 verified)
import time
import torch
from torch import nn as nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import config
from dataset import SpeechEmotionDataset, extract_mfcc
from model import SpeechEmotionModel
from utils import load_checkpoint, save_checkpoint
def train_fn(model, loader, opt, criterion, epoch):
    """Run one training epoch over `loader`.

    Args:
        model: network mapping (batch, n_mfcc, 1) features to class logits.
        loader: DataLoader yielding (feature, label) batches; labels appear
            to be one-hot, since accuracy uses label.argmax(1) -- TODO confirm.
        opt: optimizer over model.parameters().
        criterion: loss taking (logits, labels).
        epoch: epoch index, displayed in the progress bar.

    Returns:
        (mean_loss, accuracy) for the whole epoch (new, backward-compatible:
        the original returned None and callers ignore the result).
    """
    loop = tqdm(loader, leave=True)
    model.train()
    epoch_loss = 0.0
    # BUG FIX: these counters were reset at the top of every iteration, so the
    # displayed "accuracy" was only the last batch's. Accumulate over the epoch.
    total_acc, total_count = 0, 0
    for idx, (feature, label) in enumerate(loop):
        feature = feature.to(config.DEVICE)
        label = label.to(config.DEVICE)
        opt.zero_grad()
        # Add trailing channel dim: (batch, n_mfcc) -> (batch, n_mfcc, 1).
        feature = torch.unsqueeze(feature, dim=2)
        predicted_label = model(feature)
        loss = criterion(predicted_label, label)
        epoch_loss += loss.item()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        opt.step()
        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)
        # BUG FIX: divide by batches seen so far (idx + 1), not len(loader),
        # so the running loss shown mid-epoch is a true average.
        loop.set_postfix({
            "epoch": epoch,
            "loss": epoch_loss / (idx + 1),
            "accuracy": total_acc / total_count,
        })
    # max(..., 1) guards the degenerate empty-loader case.
    return epoch_loss / max(len(loader), 1), total_acc / max(total_count, 1)
def main():
    """Train the speech-emotion model for config.NUM_EPOCHS epochs.

    Builds the model/optimizer/loss, optionally resumes from a checkpoint,
    trains on config.TRAIN_DIR, evaluates on config.VAL_DIR after each epoch,
    steps an LR scheduler on the validation loss, and optionally saves a
    checkpoint every epoch.
    """
    model = SpeechEmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()
    # Consistent with test(): resume from a checkpoint when requested
    # (this was previously commented out here but active in test()).
    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)
    train_dataset = SpeechEmotionDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    val_dataset = SpeechEmotionDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,  # NOTE(review): shuffling validation data is harmless but unnecessary
        num_workers=config.NUM_WORKERS,
    )
    # Reduce the LR once validation loss stops improving for `patience` epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2, verbose=True)
    # Hoisted: the decorative table border was built inline twice per epoch.
    border = "+" + "-" * 19 + "+" + "-" * 15 + "+" + "-" * 20 + "+" + "-" * 24 + "+"
    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        train_fn(model, train_loader, opt, criterion, epoch)
        # evaluate() returns (accuracy, average_loss) in that order.
        accu_val, loss_val = evaluate(model, criterion, val_loader)
        scheduler.step(loss_val)
        print(border)
        print(
            "| end of epoch: {:3d} | time: {:6.2f}s | val_loss: {:8.3f} | "
            "val_accuracy: {:8.3f} |".format(
                epoch, time.time() - epoch_start_time, loss_val, accu_val
            )
        )
        print(border)
        if config.SAVE_MODEL:
            save_checkpoint(model, opt, filename=config.CHECKPOINT)
def test():
    """Run a single-file inference demo plus a validation-set evaluation.

    Loads the (optionally checkpointed) model, predicts the emotion of one
    hard-coded WAV file, then prints accuracy/loss on the validation set.
    """
    model = SpeechEmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT, model, opt, config.LEARNING_RATE,
        )
    val_dataset = SpeechEmotionDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    # Index -> emotion name; assumed to match val_dataset.class_to_idx, which
    # is printed below so the assumption can be checked -- TODO confirm.
    label = {0: 'anger', 1: 'disgust', 2: 'fear', 3: 'happy', 4: 'neutral', 5: 'ps', 6: 'sad'}
    mfcc = extract_mfcc("uploads/OAF_bar_fear.wav")
    mfcc = torch.from_numpy(mfcc)
    mfcc = mfcc.to(config.DEVICE)
    # (n_mfcc,) -> (1, n_mfcc, 1): add the channel dim then the batch dim.
    mfcc = torch.unsqueeze(mfcc, dim=1)
    mfcc = torch.unsqueeze(mfcc, dim=0)
    model.eval()
    # BUG FIX: disable autograd during inference; model.eval() alone does not
    # stop the graph from being built.
    with torch.no_grad():
        y_pred = model(mfcc)
    pred_idx = torch.argmax(y_pred)
    print(pred_idx)
    # FIX: `label` was previously built but never used; report the readable emotion.
    print(label[int(pred_idx)])
    print(val_dataset.class_to_idx)
    print(evaluate(model, criterion, val_loader))
def evaluate(model, criterion, dataloader):
    """Score `model` on `dataloader` without updating weights.

    Args:
        model: network mapping (batch, n_mfcc, 1) features to class logits.
        criterion: loss taking (logits, labels); labels appear to be one-hot,
            since accuracy compares against labels.argmax(1) -- TODO confirm.
        dataloader: yields (inputs, labels) batches.

    Returns:
        (accuracy, average_loss): fraction of correct argmax predictions and
        the mean per-batch criterion value.
    """
    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    batches = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(config.DEVICE)
            batch_y = batch_y.to(config.DEVICE)
            # Add trailing channel dim: (batch, n_mfcc) -> (batch, n_mfcc, 1).
            logits = model(torch.unsqueeze(batch_x, dim=2))
            loss_sum += criterion(logits, batch_y).item()
            batches += 1
            preds = logits.argmax(dim=1)
            correct += (preds == batch_y.argmax(dim=1)).sum().item()
            seen += batch_y.size(0)
    return correct / seen, loss_sum / batches
# NOTE(review): the entry point runs the inference/eval demo test(), not the
# training loop main() -- presumably intentional; confirm before shipping.
if __name__ == "__main__":
    test()