# src/utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import seaborn as sns
from sklearn.metrics import confusion_matrix

plt.style.use('ggplot')


def _ensure_dir(directory):
    """Create ``directory`` (and parents) if it is a non-empty path.

    ``os.makedirs('')`` raises ``FileNotFoundError``, which happens whenever
    ``os.path.dirname`` is applied to a bare filename, so empty paths are
    treated as "current directory, nothing to create".
    """
    if directory:
        os.makedirs(directory, exist_ok=True)


def save_model(epochs, model, optimizer, criterion, model_path):
    """Save a training checkpoint to ``model_path`` with ``torch.save``.

    Args:
        epochs: Value stored under the ``'epoch'`` key.
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        criterion: Loss object stored as-is (pickled) under ``'loss'``.
        model_path: Destination file path; its parent directory is created
            if needed.
    """
    print(f"Menyimpan model ke {model_path}")
    # dirname is '' for a bare filename; _ensure_dir skips that case.
    _ensure_dir(os.path.dirname(model_path))
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, model_path)


def _plot_train_valid(train_values, valid_values, train_label, valid_label,
                      train_color, valid_color, ylabel, out_file):
    """Draw one train/validation curve pair on a new figure, save it, close it."""
    fig = plt.figure(figsize=(10, 7))
    try:
        plt.plot(train_values, color=train_color, linestyle='-', label=train_label)
        plt.plot(valid_values, color=valid_color, linestyle='-', label=valid_label)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(out_file)
    finally:
        # Close so repeated calls do not accumulate open figures.
        plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, plot_path):
    """Save accuracy and loss curves as two PNG files inside ``plot_path``.

    Writes ``<plot_path>/accuracy.png`` and ``<plot_path>/loss.png``.

    Args:
        train_acc: Per-epoch training accuracy values.
        valid_acc: Per-epoch validation accuracy values.
        train_loss: Per-epoch training loss values.
        valid_loss: Per-epoch validation loss values.
        plot_path: Output *directory*; it is created if it does not exist.
    """
    print(f"Menyimpan plot ke {plot_path}")
    # plot_path is the directory the images go into, so create it itself
    # (creating only its parent left the target directory missing).
    _ensure_dir(plot_path)
    _plot_train_valid(
        train_acc, valid_acc,
        'train accuracy', 'validation accuracy',
        'green', 'blue',
        'Accuracy',
        os.path.join(plot_path, "accuracy.png"),
    )
    _plot_train_valid(
        train_loss, valid_loss,
        'train loss', 'validation loss',
        'orange', 'red',
        'Loss',
        os.path.join(plot_path, "loss.png"),
    )


def save_confusion_matrix(y_true, y_pred, class_names, save_path):
    """Compute a confusion matrix and save it as an annotated heatmap.

    Args:
        y_true: Ground-truth labels, as accepted by
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, same convention as ``y_true``.
        class_names: Tick labels for both axes; the length presumably must
            match the matrix size (number of distinct labels). TODO confirm
            with callers.
        save_path: Output image file path; its parent directory is created
            if needed.
    """
    cm = confusion_matrix(y_true, y_pred)
    _ensure_dir(os.path.dirname(save_path))
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
        )
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix of Best Validation Model')
        plt.savefig(save_path)
    finally:
        plt.close(fig)
    print(f"Confusion matrix disimpan di {save_path}")


class FocalLoss(nn.Module):
    """Focal loss: ``alpha * (1 - p_t) ** gamma * CE``.

    ``CE`` is the per-sample cross-entropy from ``F.cross_entropy``, and
    ``p_t = exp(-CE)``. For class-index targets this is the predicted
    probability of the true class. The effect is to down-weight
    well-classified samples.

    Args:
        alpha: Constant multiplier applied to every sample's loss.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'`` to reduce over samples. Any other
            value returns the unreduced per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss for logits ``inputs`` and ``targets``.

        ``inputs``/``targets`` follow the conventions of ``F.cross_entropy``
        (unnormalized logits and target labels).
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the probability assigned to the target.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss