# Cataract-ViT / src/utils.py
# Uploaded by Decoder24 via huggingface_hub (commit a080b32, verified).
# src/utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Apply matplotlib's 'ggplot' style at import time. This updates matplotlib's
# global style settings, so it also affects figures created elsewhere in the
# same process after this module is imported.
plt.style.use('ggplot')
def save_model(epochs, model, optimizer, criterion, model_path):
    """Save a training checkpoint to ``model_path`` using ``torch.save``.

    The checkpoint is a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.
    Note that ``'loss'`` stores the ``criterion`` object itself (not a scalar
    loss value), so loading the checkpoint needs full unpickling rather than a
    weights-only load.

    Args:
        epochs: Epoch number recorded under ``'epoch'``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        criterion: Loss object stored under ``'loss'``.
        model_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    print(f"Menyimpan model ke {model_path}")
    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(model_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, model_path)
def _save_curve_plot(series, ylabel, out_file):
    """Plot ``(values, color, label)`` series against epochs and save to ``out_file``."""
    fig = plt.figure(figsize=(10, 7))
    try:
        for values, color, label in series:
            plt.plot(values, color=color, linestyle='-', label=label)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(out_file)
    finally:
        # Close the figure so repeated calls (e.g. once per epoch) do not
        # accumulate open matplotlib figures.
        plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, plot_path):
    """Save accuracy and loss curves as PNG files inside directory ``plot_path``.

    Writes ``accuracy.png`` (train vs. validation accuracy) and ``loss.png``
    (train vs. validation loss), plotting one point per list element.

    Args:
        train_acc: Training accuracy values, one per epoch.
        valid_acc: Validation accuracy values, one per epoch.
        train_loss: Training loss values, one per epoch.
        valid_loss: Validation loss values, one per epoch.
        plot_path: Output directory; created (with parents) if it is missing.
    """
    print(f"Menyimpan plot ke {plot_path}")
    # plot_path is used as a directory below, so create it (not just its
    # parent); otherwise savefig fails when the directory does not exist yet.
    if plot_path:
        os.makedirs(plot_path, exist_ok=True)
    _save_curve_plot(
        [(train_acc, 'green', 'train accuracy'),
         (valid_acc, 'blue', 'validation accuracy')],
        'Accuracy',
        os.path.join(plot_path, 'accuracy.png'),
    )
    _save_curve_plot(
        [(train_loss, 'orange', 'train loss'),
         (valid_loss, 'red', 'validation loss')],
        'Loss',
        os.path.join(plot_path, 'loss.png'),
    )
def save_confusion_matrix(y_true, y_pred, class_names, save_path, labels=None):
    """Compute a confusion matrix and save it as an annotated heatmap image.

    Rows are true labels and columns are predicted labels, as returned by
    ``sklearn.metrics.confusion_matrix``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Tick labels for both axes, in the same order as the
            matrix rows/columns.
        save_path: Output image file. Missing parent directories are created.
        labels: Optional label values forwarded to ``confusion_matrix`` to fix
            the row/column order. With the default ``None``, sklearn uses the
            sorted labels that occur in ``y_true``/``y_pred``, so a class that
            never occurs is dropped and the matrix no longer lines up with
            ``class_names``. If labels are integer class indices, pass e.g.
            ``range(len(class_names))`` to always keep every class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # Create the output directory, consistent with the other save helpers;
    # skip it for a bare filename (os.makedirs('') would raise).
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix of Best Validation Model')
        plt.savefig(save_path)
    finally:
        # Always release the figure so repeated calls do not leak memory.
        plt.close(fig)
    print(f"Confusion matrix disimpan di {save_path}")
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification (Lin et al., 2017).

    Scales the per-element cross-entropy by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t = exp(-CE)`` is the predicted probability of the target class.
    This means confidently-correct examples contribute less to the loss.

    Args:
        alpha: Either a scalar weight applied to every element (default ``1``),
            or a per-class sequence / 1-D tensor of length ``C``. A per-class
            value is indexed by each element's integer target (the ``alpha_t``
            weighting).
        gamma: Focusing exponent; ``gamma=0`` gives alpha-scaled cross-entropy.
        reduction: ``'mean'`` or ``'sum'`` reduce over all elements; any other
            value returns the unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _alpha_factor(self, targets, ce_loss):
        """Return ``alpha`` unchanged if scalar, else per-element class weights."""
        alpha = self.alpha
        is_per_class = isinstance(alpha, (list, tuple)) or (
            torch.is_tensor(alpha) and alpha.dim() > 0
        )
        if not is_per_class:
            return alpha
        if targets.is_floating_point():
            raise ValueError(
                "FocalLoss: per-class alpha requires integer class-index targets"
            )
        weights = torch.as_tensor(alpha, dtype=ce_loss.dtype, device=ce_loss.device)
        # Elements at F.cross_entropy's default ignore_index (-100) already have
        # zero loss; clamp so the lookup stays in range (weight * 0 == 0).
        return weights[targets.clamp(min=0)]

    def forward(self, inputs, targets):
        """Compute the focal loss of ``inputs`` (unnormalized logits) vs ``targets``.

        ``inputs`` and ``targets`` are passed straight to
        ``F.cross_entropy(..., reduction='none')``, so they follow its shape
        conventions.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        alpha = self._alpha_factor(targets, ce_loss)
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss