|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import matplotlib.pyplot as plt
|
|
|
import os
|
|
|
import seaborn as sns
|
|
|
from sklearn.metrics import confusion_matrix
|
|
|
|
|
|
# Apply matplotlib's 'ggplot' style sheet at import time. This changes the
# global rcParams, so it affects every figure drawn afterwards in the process,
# not only the figures created by the helpers in this module.
plt.style.use('ggplot')
|
|
|
|
|
|
def save_model(epochs, model, optimizer, criterion, model_path):
    """Save a training checkpoint to ``model_path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        epochs: Value stored under ``'epoch'`` (presumably the number of
            completed epochs -- confirm against the caller).
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        criterion: Loss object stored as-is under ``'loss'`` (the object
            itself, not a state dict).
        model_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    print(f"Menyimpan model ke {model_path}")

    # os.path.dirname() returns '' for a bare filename such as "model.pth",
    # and os.makedirs('') raises FileNotFoundError, so only create a
    # directory when the path actually has one.
    parent_dir = os.path.dirname(model_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, model_path)
|
|
|
|
|
|
def _save_line_plot(series, ylabel, out_file):
    """Draw labelled line series against epochs on a new figure and save it.

    Args:
        series: Iterable of ``(values, color, label)`` tuples, one per line.
        ylabel: Y-axis label.
        out_file: Image file path passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    for values, color, label in series:
        ax.plot(values, color=color, linestyle='-', label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(out_file)
    # pyplot keeps every open figure alive until it is closed; close it so
    # repeated calls do not accumulate figures (and memory).
    plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, plot_path):
    """Save training/validation accuracy and loss curves as PNG images.

    Two files are written inside the directory ``plot_path``:
    ``accuracy.png`` (train vs. validation accuracy) and ``loss.png``
    (train vs. validation loss). Each value sequence is plotted against
    its index on the x-axis, labelled "Epochs".

    Args:
        train_acc: Per-epoch training accuracy values.
        valid_acc: Per-epoch validation accuracy values.
        train_loss: Per-epoch training loss values.
        valid_loss: Per-epoch validation loss values.
        plot_path: Output directory; created (with parents) if missing.
    """
    print(f"Menyimpan plot ke {plot_path}")
    # plot_path is the output *directory* (both images are written inside it),
    # so create the directory itself rather than only its parent.
    os.makedirs(plot_path, exist_ok=True)

    _save_line_plot(
        [
            (train_acc, 'green', 'train accuracy'),
            (valid_acc, 'blue', 'validation accuracy'),
        ],
        'Accuracy',
        os.path.join(plot_path, 'accuracy.png'),
    )
    _save_line_plot(
        [
            (train_loss, 'orange', 'train loss'),
            (valid_loss, 'red', 'validation loss'),
        ],
        'Loss',
        os.path.join(plot_path, 'loss.png'),
    )
|
|
|
|
|
|
def save_confusion_matrix(y_true, y_pred, class_names, save_path, labels=None):
    """Compute a confusion matrix and save it as an annotated heatmap image.

    Args:
        y_true: Ground-truth labels, passed to sklearn's ``confusion_matrix``.
        y_pred: Predicted labels, passed to sklearn's ``confusion_matrix``.
        class_names: Tick labels for both axes, in matrix row/column order.
        save_path: Output image file path; missing parent directories are
            created.
        labels: Optional explicit label list forwarded to
            ``confusion_matrix``. With the default ``None``, sklearn only uses
            labels that occur in ``y_true`` or ``y_pred``, so the matrix can
            be smaller than ``class_names`` when a class is absent. Pass the
            full label list to keep rows/columns aligned with ``class_names``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Match the sibling save helpers: make sure the destination directory
    # exists (dirname is '' for a bare filename, which needs no directory).
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',  # integer counts
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix of Best Validation Model')
    plt.savefig(save_path)
    # Release the figure; pyplot would otherwise keep it open indefinitely.
    plt.close(fig)
    print(f"Confusion matrix disimpan di {save_path}")
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross-entropy.

    Each sample's cross-entropy ``CE`` is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-CE)``. This
    down-weights samples whose ``p_t`` is close to 1.

    Args:
        alpha: Scalar multiplier applied to every sample's loss.
        gamma: Focusing exponent; ``gamma=0`` gives ``alpha * CE``.
        reduction: ``'mean'``, ``'sum'``, or ``'none'`` (return the
            per-sample losses unreduced).

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        # Fail fast on typos instead of silently returning unreduced losses.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unsupported reduction {reduction!r}; "
                f"expected one of {self._REDUCTIONS}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` with respect to ``targets``.

        ``inputs`` and ``targets`` are passed unchanged to
        ``F.cross_entropy`` (so presumably raw logits and class indices --
        confirm against the caller).
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # For hard class-index targets, exp(-CE) is the softmax probability
        # the model assigned to the target class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        if self.reduction == 'none':
            return focal_loss
        # Reached only if ``reduction`` was changed after construction.
        raise ValueError(
            f"Unsupported reduction {self.reduction!r}; "
            f"expected one of {self._REDUCTIONS}"
        )