|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from tqdm import tqdm
|
|
|
import numpy as np
|
|
|
from sklearn.metrics import classification_report
|
|
|
from torch.utils.data import DataLoader, WeightedRandomSampler
|
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
import copy
|
|
|
|
|
|
|
|
|
from model import create_model
|
|
|
from dataset import get_dataloaders
|
|
|
from utils import save_model, save_plots, save_confusion_matrix, FocalLoss
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Runtime and filesystem locations
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
DATA_DIR = 'data'                  # root passed to get_dataloaders
OUTPUT_DIR = 'outputs/new_model2'  # checkpoint, plots and confusion matrix go here

# ---------------------------------------------------------------------------
# Input pipeline
# ---------------------------------------------------------------------------
IMAGE_SIZE = 224   # forwarded to both get_dataloaders and create_model
BATCH_SIZE = 16
NUM_WORKERS = 4    # DataLoader worker processes

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
EPOCHS = 50                       # upper bound on stage-2 fine-tuning epochs (early stopping may end sooner)
LEARNING_RATE_HEAD = 0.001        # AdamW LR for stage 1 (classifier head only)
LEARNING_RATE_FINETUNE = 0.00002  # AdamW LR for stage 2 (whole model)
WEIGHT_DECAY = 0.05               # shared by both AdamW optimizers
PATIENCE = 7                      # stage-2 epochs without val-accuracy gain before stopping

# ---------------------------------------------------------------------------
# Model / checkpoint
# ---------------------------------------------------------------------------
NUM_CLASSES = 4
# NOTE(review): the file name mentions "Focal", but the script trains with
# CrossEntropyLoss (label smoothing); kept unchanged so existing paths still work.
MODEL_NAME = 'best_model_final_TTA_Focal.pth'
|
|
|
|
|
|
|
|
|
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; it is switched to ``train()`` mode here.
        dataloader: Iterable of ``(images, labels)`` batches that supports
            ``len()`` (used for the progress bar).
        optimizer: Optimizer that steps the trainable parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats. The loss is the
        sample-weighted mean of the per-batch losses. The accuracy is the
        fraction of samples whose argmax prediction matches the label. Both
        are ``0.0`` when the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for images, labels in tqdm(dataloader, total=len(dataloader), desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch loss by batch size so a smaller final batch
        # does not skew the epoch average (assumes criterion reduces by mean
        # -- the default for nn.CrossEntropyLoss).
        running_loss += loss.item() * batch_size
        preds = outputs.argmax(dim=1)
        # .item() keeps the counter a plain Python int instead of letting it
        # turn into a device tensor after the first batch.
        correct_predictions += (preds == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        # Empty loader: avoid ZeroDivisionError and report neutral metrics.
        return 0.0, 0.0

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc
|
|
|
|
|
|
|
|
|
def validate_one_epoch_tta(model, dataloader, criterion, device):
    """Evaluate ``model`` with test-time augmentation (TTA).

    Each batch is scored on four views: the original, a flip along dim 3, a
    flip along dim 2, and a 90-degree rotation over dims (2, 3). Assuming
    NCHW batches, these are a horizontal flip, a vertical flip and a
    rotation. The per-view softmax probabilities are averaged, and the
    averaged distribution gives both the loss and the prediction. A
    per-class classification report is printed at the end.

    Args:
        model: Network to evaluate; it is switched to ``eval()`` mode here.
        dataloader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Logit-based loss such as ``nn.CrossEntropyLoss``. It is
            given the log of the averaged probabilities (see below).
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_labels, all_preds)``. The loss and
        accuracy are Python floats (``0.0`` for an empty loader).
        ``all_labels`` and ``all_preds`` are flat lists of per-sample class
        indices.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, total=len(dataloader), desc="Validating with TTA"):
            images, labels = images.to(device), labels.to(device)

            # NOTE(review): rot90 swaps H and W; this is harmless only if the
            # model accepts the rotated size (true for square inputs).
            views = (
                images,
                torch.flip(images, dims=[3]),
                torch.flip(images, dims=[2]),
                torch.rot90(images, k=1, dims=[2, 3]),
            )
            probs_avg = torch.stack(
                [torch.softmax(model(view), dim=1) for view in views]
            ).mean(dim=0)

            # CrossEntropyLoss expects logits and applies log_softmax itself.
            # Feeding it log-probabilities is exact, since
            # log_softmax(log p) == log p when p sums to 1. The loss is then
            # the true (label-smoothed) NLL of the TTA ensemble. Passing the
            # probabilities directly, as before, double-normalised them and
            # produced a squashed, misleading loss. The clamp guards against
            # log(0).
            log_probs = torch.log(probs_avg.clamp_min(1e-12))
            loss = criterion(log_probs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            preds = probs_avg.argmax(dim=1)
            # Keep the counter a Python int rather than a device tensor.
            correct_predictions += (preds == labels).sum().item()
            total_samples += batch_size

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total_samples:
        epoch_loss = running_loss / total_samples
        epoch_acc = correct_predictions / total_samples
    else:
        epoch_loss, epoch_acc = 0.0, 0.0

    if all_labels:
        print("\n--- Laporan Klasifikasi Validasi (dengan TTA) ---")
        # Passing `labels` explicitly keeps target_names aligned and avoids
        # sklearn's ValueError when some class is absent from both labels
        # and predictions.
        print(classification_report(
            all_labels,
            all_preds,
            labels=list(range(NUM_CLASSES)),
            target_names=[str(i) for i in range(NUM_CLASSES)],
            zero_division=0,
        ))

    return epoch_loss, epoch_acc, all_labels, all_preds
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import os

    # The training loader returned by get_dataloaders is discarded and
    # rebuilt below with a class-balanced sampler.
    _, valid_loader, classes, train_dataset = get_dataloaders(DATA_DIR, BATCH_SIZE, IMAGE_SIZE, NUM_WORKERS)

    # --- Class-balanced sampling ------------------------------------------
    # Each sample is weighted by 1 / (size of its class). Sampling with
    # replacement then draws every class with roughly equal probability.
    print("\n--- Menyiapkan Balanced Sampler untuk Training ---")
    class_counts = np.bincount(train_dataset.targets)
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)

    model = create_model(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(DEVICE)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1).to(DEVICE)
    print("Menggunakan CrossEntropyLoss dengan Label Smoothing (0.1).")

    # --- Stage 1: train only the classifier head (everything else frozen) --
    head_epochs = 5
    print("\n--- TAHAP 1: Melatih Classifier Head ---")
    for param in model.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True
    optimizer_head = optim.AdamW(model.head.parameters(), lr=LEARNING_RATE_HEAD, weight_decay=WEIGHT_DECAY)
    for epoch in range(head_epochs):
        print(f"Epoch Head {epoch+1}/{head_epochs}")
        train_one_epoch(model, train_loader, optimizer_head, criterion, DEVICE)
        validate_one_epoch_tta(model, valid_loader, criterion, DEVICE)

    # --- Stage 2: fine-tune all parameters with cosine LR annealing -------
    print("\n--- TAHAP 2: Fine-tuning Seluruh Model ---")
    for param in model.parameters():
        param.requires_grad = True
    optimizer_finetune = optim.AdamW(model.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer_finetune, T_max=EPOCHS, eta_min=1e-7)
    print("Menggunakan scheduler CosineAnnealingLR.")

    # The directory handling of save_model/save_plots is not visible here,
    # so make sure the output directory exists before anything is written.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    history = {'train_loss': [], 'train_acc': [], 'valid_loss': [], 'valid_acc': []}
    best_valid_acc = 0.0
    best_labels, best_preds = None, None
    best_model_wts = copy.deepcopy(model.state_dict())
    patience_counter = 0

    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")

        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer_finetune, criterion, DEVICE)
        valid_loss, valid_acc, valid_labels, valid_preds = validate_one_epoch_tta(model, valid_loader, criterion, DEVICE)

        # CosineAnnealingLR follows a fixed per-epoch schedule and takes no
        # metric. The former `scheduler.step(valid_loss)` passed the loss as
        # the deprecated `epoch` argument. That recomputed the LR from the
        # loss value instead of the epoch count, so it stayed close to its
        # initial value instead of annealing.
        scheduler.step()

        print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f" Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)

        # Checkpoint on strict validation-accuracy improvement; otherwise
        # count towards early stopping.
        if valid_acc > best_valid_acc:
            print(f"Validasi akurasi meningkat dari {best_valid_acc:.4f} ke {valid_acc:.4f}. Menyimpan model...")
            save_model(epoch, model, optimizer_finetune, criterion, f"{OUTPUT_DIR}/{MODEL_NAME}")
            best_valid_acc = valid_acc
            best_labels = valid_labels
            best_preds = valid_preds
            best_model_wts = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping di epoch {epoch+1} karena tidak ada peningkatan validasi selama {PATIENCE} epoch.")
                break

    # Restore the best-scoring weights before producing final artefacts.
    model.load_state_dict(best_model_wts)
    save_plots(history['train_acc'], history['valid_acc'], history['train_loss'], history['valid_loss'], OUTPUT_DIR)

    # Explicit None checks: these stay None only if no epoch ever improved.
    if best_labels is not None and best_preds is not None:
        save_confusion_matrix(best_labels, best_preds, classes, f"{OUTPUT_DIR}/confusion_matrix.png")

    print("\n--- Selesai ---")
    print(f"Model terbaik disimpan di {OUTPUT_DIR}/{MODEL_NAME}")
|
|
|
|