File size: 2,819 Bytes
a080b32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# src/utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Apply matplotlib's 'ggplot' style sheet. This call runs once at import time
# and updates pyplot's global rcParams, so it affects every figure drawn in the
# process afterwards, not only the plots produced by this module.
plt.style.use('ggplot')
def save_model(epochs, model, optimizer, criterion, model_path):
    """
    Save a training checkpoint to ``model_path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        epochs: Value stored under ``'epoch'``.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        criterion: Stored as-is under ``'loss'``. The object itself is
            pickled, not a numeric loss value.
        model_path: Destination file path. Missing parent directories are
            created; a bare filename is saved in the current directory.
    """
    print(f"Menyimpan model ke {model_path}")
    # os.path.dirname() returns "" for a bare filename such as "model.pth",
    # and os.makedirs("") raises FileNotFoundError, so only create a real dir.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    # NOTE(review): because 'loss' holds the criterion object (e.g. an
    # nn.Module) rather than plain tensors, loading this file may require
    # torch.load(..., weights_only=False) on newer PyTorch versions.
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, model_path)
def _save_line_plot(series, ylabel, out_file):
    """Draw (values, color, label) line series on a fresh figure, save it to out_file, and close it."""
    fig, ax = plt.subplots(figsize=(10, 7))
    for values, color, label in series:
        ax.plot(values, color=color, linestyle='-', label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(out_file)
    # pyplot keeps every figure alive until it is closed; without this,
    # repeated calls (e.g. once per epoch) accumulate figures and memory.
    plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, plot_path):
    """
    Save accuracy and loss curves as two PNG files inside ``plot_path``.

    Writes ``<plot_path>/accuracy.png`` (train vs. validation accuracy) and
    ``<plot_path>/loss.png`` (train vs. validation loss), with epochs on the
    x-axis.

    Args:
        train_acc: Per-epoch training accuracy values.
        valid_acc: Per-epoch validation accuracy values.
        train_loss: Per-epoch training loss values.
        valid_loss: Per-epoch validation loss values.
        plot_path: Output directory. It is created if it does not exist.
    """
    print(f"Menyimpan plot ke {plot_path}")
    # plot_path is a directory (both files are written inside it), so create
    # the directory itself; this also creates any missing parents.
    os.makedirs(plot_path, exist_ok=True)
    _save_line_plot(
        [
            (train_acc, 'green', 'train accuracy'),
            (valid_acc, 'blue', 'validation accuracy'),
        ],
        'Accuracy',
        os.path.join(plot_path, 'accuracy.png'),
    )
    _save_line_plot(
        [
            (train_loss, 'orange', 'train loss'),
            (valid_loss, 'red', 'validation loss'),
        ],
        'Loss',
        os.path.join(plot_path, 'loss.png'),
    )
def save_confusion_matrix(y_true, y_pred, class_names, save_path, *, labels=None):
    """
    Plot a confusion matrix as an annotated heatmap and save it to ``save_path``.

    Args:
        y_true: Ground-truth labels, passed to sklearn's ``confusion_matrix``.
        y_pred: Predicted labels, encoded the same way as ``y_true``.
        class_names: Tick labels for both axes, in matrix row/column order.
        save_path: Output image file path. Its parent directory is created
            if missing.
        labels: Optional label values forwarded to ``confusion_matrix`` to
            fix the set and order of rows/columns. With the default ``None``,
            the matrix only covers labels present in ``y_true``/``y_pred``.
            It can therefore be smaller than ``class_names`` when a class
            never occurs. For integer-encoded targets, pass e.g.
            ``labels=list(range(len(class_names)))`` to keep them aligned.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # Create the output directory, as the other save_* helpers do; skip when
    # save_path is a bare filename (dirname "" would make os.makedirs fail).
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix of Best Validation Model')
    fig.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Confusion matrix disimpan di {save_path}")
class FocalLoss(nn.Module):
    """
    Focal loss built on top of ``F.cross_entropy``.

    Per element it computes ``alpha * (1 - p_t) ** gamma * CE``. Here ``CE``
    is the unreduced cross-entropy and ``p_t = exp(-CE)`` is the model's
    probability for the target class. ``gamma`` down-weights well-classified
    elements. With ``gamma=0`` and ``alpha=1``, the loss equals plain
    cross-entropy.

    Args:
        alpha: Multiplier applied to every element's loss (default 1).
        gamma: Focusing exponent (default 2).
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        # Fail fast: an unknown value (e.g. a typo) used to silently disable
        # reduction and return the unreduced per-element tensor.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss.

        Args:
            inputs: Raw (unnormalized) logits in any layout that
                ``F.cross_entropy`` accepts. Presumably (N, C); confirm with callers.
            targets: Targets in the matching ``F.cross_entropy`` format.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise the
            unreduced per-element loss tensor.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        # Clamp at 0: if rounding ever pushes pt a hair above 1, a fractional
        # gamma applied to a negative base would produce NaN.
        focal_loss = self.alpha * (1 - pt).clamp(min=0) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss