|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
    """Run one training epoch over ``dataloader``.

    Puts the model in training mode and, for every batch: moves it to
    ``device``, runs a forward pass, computes the loss, backpropagates, and
    updates the weights with ``optimizer``.

    Args:
        model: Model to train; its parameters are updated in place by
            ``optimizer.step()``.
        dataloader: Yields ``(X, y)`` batches. ``y`` is compared element-wise
            with the argmax of the model output along dim 1, so it is expected
            to hold class indices.
        loss_fn: Loss called as ``loss_fn(logits, y)``. The returned loss
            average assumes it reduces with the mean over the batch (the
            default ``reduction='mean'`` of ``torch.nn`` losses).
        optimizer: Optimizer that steps the model's parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(train_loss, train_acc)``: the per-sample average loss and the
        fraction of correctly classified samples over the whole epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for X, y in tqdm(dataloader, desc="Training"):
        X, y = X.to(device), y.to(device)

        # Forward pass and loss.
        y_pred_logits = model(X)
        loss = loss_fn(y_pred_logits, y)

        # Backward pass and weight update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not
        # count as much as a full batch in the epoch averages.
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        y_pred_class = torch.argmax(y_pred_logits, dim=1)
        total_correct += (y_pred_class == y).sum().item()
        total_samples += batch_size

    # Counting samples while iterating (instead of len(dataloader)) avoids a
    # ZeroDivisionError on an empty loader and also works for loaders that
    # have no __len__ (e.g. ones backed by an IterableDataset).
    if total_samples == 0:
        raise ValueError(
            "dataloader yielded no samples; cannot compute training metrics")

    train_loss = total_loss / total_samples
    train_acc = total_correct / total_samples
    return train_loss, train_acc
|
|
|
|
|
|
def val_step(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             device: torch.device):
    """Run one validation epoch over ``dataloader``.

    Puts the model in evaluation mode and, with gradient tracking disabled,
    runs a forward pass on every batch and accumulates loss and accuracy.
    No backpropagation or weight update is performed.

    Args:
        model: Model to evaluate.
        dataloader: Yields ``(X, y)`` batches. ``y`` is compared element-wise
            with the argmax of the model output along dim 1, so it is expected
            to hold class indices.
        loss_fn: Loss called as ``loss_fn(logits, y)``. The returned loss
            average assumes it reduces with the mean over the batch (the
            default ``reduction='mean'`` of ``torch.nn`` losses).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the per-sample average loss and the fraction
        of correctly classified samples over the whole epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Gradients are not needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        for X, y in tqdm(dataloader, desc="Validasi"):
            X, y = X.to(device), y.to(device)

            y_pred_logits = model(X)
            loss = loss_fn(y_pred_logits, y)

            # Weight each batch by its size so a smaller final batch does
            # not count as much as a full batch in the epoch averages.
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            y_pred_class = torch.argmax(y_pred_logits, dim=1)
            total_correct += (y_pred_class == y).sum().item()
            total_samples += batch_size

    # Counting samples while iterating (instead of len(dataloader)) avoids a
    # ZeroDivisionError on an empty loader and also works for loaders that
    # have no __len__ (e.g. ones backed by an IterableDataset).
    if total_samples == 0:
        raise ValueError(
            "dataloader yielded no samples; cannot compute validation metrics")

    val_loss = total_loss / total_samples
    val_acc = total_correct / total_samples
    return val_loss, val_acc