# src/train.py (TTA + early stopping; trains with label-smoothed CrossEntropyLoss, FocalLoss imported as an alternative)
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, WeightedRandomSampler
import copy
# Imports from other files in this project
from model import create_model
from dataset import get_dataloaders
from utils import save_model, save_plots, save_confusion_matrix, FocalLoss
# --- 1. CONFIGURATION & HYPERPARAMETERS ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR = 'data'
OUTPUT_DIR = 'outputs/new_model2'
IMAGE_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 4
EPOCHS = 50
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINETUNE = 2e-5
WEIGHT_DECAY = 0.05
MODEL_NAME = 'best_model_final_TTA_Focal.pth'
NUM_CLASSES = 4
PATIENCE = 7  # epochs without validation improvement before early stopping
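# LEARNING_RATE_HEAD is used in Stage 1 (frozen backbone, head only);
# LEARNING_RATE_FINETUNE is the much smaller rate for Stage 2, where the
# whole pretrained backbone is unfrozen and fine-tuned.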
# --- 2. TRAINING & VALIDATION ---
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss, correct_predictions, total_samples = 0.0, 0, 0
    for images, labels in tqdm(dataloader, total=len(dataloader), desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so epoch_loss is a true per-sample mean
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct_predictions += torch.sum(preds == labels.data)
        total_samples += labels.size(0)
    epoch_loss = running_loss / total_samples
    epoch_acc = (correct_predictions.double() / total_samples).item()
    return epoch_loss, epoch_acc
def validate_one_epoch_tta(model, dataloader, criterion, device):
    """
    Validation with Test-Time Augmentation (TTA).
    Averages the predictions over several simple augmentations.
    """
    model.eval()
    running_loss, correct_predictions, total_samples = 0.0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, total=len(dataloader), desc="Validating with TTA"):
            images, labels = images.to(device), labels.to(device)
            # TTA views: original, horizontal flip, vertical flip, rotate 90°
            outputs_list = []
            outputs_list.append(model(images))                                 # original
            outputs_list.append(model(torch.flip(images, dims=[3])))           # horizontal flip
            outputs_list.append(model(torch.flip(images, dims=[2])))           # vertical flip
            outputs_list.append(model(torch.rot90(images, k=1, dims=[2, 3])))  # rotate 90°
            # Average the softmax probabilities across all views for the prediction
            outputs_avg = torch.mean(
                torch.stack([torch.softmax(out, dim=1) for out in outputs_list]), dim=0
            )
            # CrossEntropyLoss expects raw logits, so compute the loss on the
            # un-augmented logits; feeding the averaged probabilities through it
            # would apply softmax a second time.
            loss = criterion(outputs_list[0], labels)
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs_avg, 1)
            correct_predictions += torch.sum(preds == labels.data)
            total_samples += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    epoch_loss = running_loss / total_samples
    epoch_acc = (correct_predictions.double() / total_samples).item()
    print("\n--- Validation Classification Report (with TTA) ---")
    print(classification_report(all_labels, all_preds, target_names=[str(i) for i in range(NUM_CLASSES)], zero_division=0))
    return epoch_loss, epoch_acc, all_labels, all_preds
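# Note: TTA runs four forward passes per batch, so validation is roughly
# four times slower than a single plain pass; more views (e.g. 180°/270°
# rotations) can be added to outputs_list in the same way.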
# --- 3. MAIN SCRIPT ---
if __name__ == '__main__':
    _, valid_loader, classes, train_dataset = get_dataloaders(DATA_DIR, BATCH_SIZE, IMAGE_SIZE, NUM_WORKERS)

    print("\n--- Preparing a Balanced Sampler for Training ---")
    # Inverse-frequency class weights: rarer classes get proportionally
    # higher sampling probability, so each batch is roughly class-balanced.
    class_counts = np.bincount(train_dataset.targets)
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[train_dataset.targets]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)

    model = create_model(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1).to(DEVICE)
    print("Using CrossEntropyLoss with label smoothing (0.1).")
    # --- STAGE 1: FREEZE BACKBONE, TRAIN HEAD ---
    print("\n--- STAGE 1: Training the Classifier Head ---")
    for param in model.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True
    optimizer_head = optim.AdamW(model.head.parameters(), lr=LEARNING_RATE_HEAD, weight_decay=WEIGHT_DECAY)
    for epoch in range(5):
        print(f"Head Epoch {epoch+1}/5")
        train_one_epoch(model, train_loader, optimizer_head, criterion, DEVICE)
        validate_one_epoch_tta(model, valid_loader, criterion, DEVICE)
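    # Warming up the randomly initialized head while the backbone is frozen
    # keeps large early gradients from disturbing the pretrained weights
    # before full fine-tuning begins in Stage 2.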
    # --- STAGE 2: UNFREEZE & FINE-TUNE THE WHOLE MODEL ---
    print("\n--- STAGE 2: Fine-tuning the Whole Model ---")
    for param in model.parameters():
        param.requires_grad = True
    optimizer_finetune = optim.AdamW(model.parameters(), lr=LEARNING_RATE_FINETUNE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer_finetune, T_max=EPOCHS, eta_min=1e-7)
    print("Using the CosineAnnealingLR scheduler.")

    history = {'train_loss': [], 'train_acc': [], 'valid_loss': [], 'valid_acc': []}
    best_valid_acc = 0.0
    best_labels, best_preds = None, None
    best_model_wts = copy.deepcopy(model.state_dict())
    patience_counter = 0
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer_finetune, criterion, DEVICE)
        valid_loss, valid_acc, valid_labels, valid_preds = validate_one_epoch_tta(model, valid_loader, criterion, DEVICE)
        # CosineAnnealingLR steps once per epoch and takes no metric
        # (passing valid_loss is the ReduceLROnPlateau convention)
        scheduler.step()
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)

        # Early-stopping check: save on improvement, otherwise count down patience
        if valid_acc > best_valid_acc:
            print(f"Validation accuracy improved from {best_valid_acc:.4f} to {valid_acc:.4f}. Saving model...")
            save_model(epoch, model, optimizer_finetune, criterion, f"{OUTPUT_DIR}/{MODEL_NAME}")
            best_valid_acc = valid_acc
            best_labels = valid_labels
            best_preds = valid_preds
            best_model_wts = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1}: no validation improvement for {PATIENCE} epochs.")
                break
    # Restore the best weights before saving plots and the confusion matrix
    model.load_state_dict(best_model_wts)
    save_plots(history['train_acc'], history['valid_acc'], history['train_loss'], history['valid_loss'], OUTPUT_DIR)
    if best_labels is not None and best_preds is not None:
        save_confusion_matrix(best_labels, best_preds, classes, f"{OUTPUT_DIR}/confusion_matrix.png")
    print("\n--- Done ---")
    print(f"Best model saved to {OUTPUT_DIR}/{MODEL_NAME}")