# lofi-bytes / utilities/run_model.py
# Author: amosyou
# Commit message: feat: add lofi-bytes-api and gradio app
# Commit: 7116323
import torch
import time
from .constants import *
from utilities.device import get_device
from .lr_scheduling import get_lr
from dataset.e_piano import compute_epiano_accuracy
# train_epoch
def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1, device=None):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Trains a single model epoch.

    For every batch: zeroes the gradients, runs the model on batch[0], computes the loss
    against batch[1], backpropagates, then steps the optimizer and (if given) the LR scheduler.

    Parameters:
        cur_epoch:      Epoch number; only used in the progress printout.
        model:          Model to train; switched to training mode via model.train().
        dataloader:     Iterable of batches where batch[0] is the input and batch[1] the target.
        loss:           Loss object whose .forward(prediction, target) returns a scalar tensor
                        (it is backpropagated with .backward() and converted with float()).
        opt:            Optimizer; zero_grad() and step() are called once per batch.
        lr_scheduler:   Optional scheduler; step() is called once per batch, not once per epoch.
        print_modulus:  Print progress every print_modulus batches; values <= 0 disable printing.
        device:         Device the batches are moved to; defaults to get_device().

    Returns None.
    ----------
    """
    # Resolve the device once instead of twice per batch.
    if device is None:
        device = get_device()

    model.train()
    for batch_num, batch in enumerate(dataloader):
        # perf_counter is monotonic, so the measured duration is immune to wall-clock jumps.
        time_before = time.perf_counter()

        opt.zero_grad()

        x = batch[0].to(device)
        tgt = batch[1].to(device)

        y = model(x)

        # Merge the first two output dims (presumably batch x sequence — confirm against the
        # model) into rows, and flatten the targets to match, before computing the loss.
        y = y.reshape(y.shape[0] * y.shape[1], -1)
        tgt = tgt.flatten()

        out = loss.forward(y, tgt)

        out.backward()
        opt.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        time_took = time.perf_counter() - time_before

        # Guard against print_modulus <= 0, which previously raised ZeroDivisionError.
        if print_modulus > 0 and (batch_num + 1) % print_modulus == 0:
            print(SEPERATOR)
            print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader))
            print("LR:", get_lr(opt))
            print("Train loss:", float(out))
            print("")
            print("Time (s):", time_took)
            print(SEPERATOR)
            print("")

    return
# eval_model
def eval_model(model, dataloader, loss, device=None):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Evaluates the model and returns the average loss and accuracy.

    Runs with gradients disabled after calling model.eval(); the model is left in evaluation
    mode on return. Loss and accuracy are averaged per batch, so every batch counts equally
    regardless of its size.

    Parameters:
        model:      Model to evaluate.
        dataloader: Iterable of batches where batch[0] is the input and batch[1] the target.
                    len() is not required; batches are counted while iterating.
        loss:       Loss object whose .forward(prediction, target) returns a scalar tensor.
        device:     Device the batches are moved to; defaults to get_device().

    Returns:
        (avg_loss, avg_acc) as floats, or (-1, -1) if the dataloader yields no batches.
    ----------
    """
    model.eval()

    # Resolve the device once instead of twice per batch.
    if device is None:
        device = get_device()

    n_batches = 0
    sum_loss = 0.0
    sum_acc = 0.0
    with torch.no_grad():
        for batch in dataloader:
            x = batch[0].to(device)
            tgt = batch[1].to(device)

            y = model(x)

            # Accuracy is computed on the un-flattened output, before the reshape below.
            sum_acc += float(compute_epiano_accuracy(y, tgt))

            # Merge the first two output dims into rows and flatten targets to match for the loss.
            y = y.reshape(y.shape[0] * y.shape[1], -1)
            tgt = tgt.flatten()

            sum_loss += float(loss.forward(y, tgt))
            n_batches += 1

    # No batches: return the -1 sentinels instead of dividing by zero.
    if n_batches == 0:
        return -1, -1

    return sum_loss / n_batches, sum_acc / n_batches