File size: 3,228 Bytes
a6eed2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
from tqdm.auto import tqdm # Untuk progress bar yang bagus

def train_step(model: torch.nn.Module, 

               dataloader: torch.utils.data.DataLoader, 

               loss_fn: torch.nn.Module, 

               optimizer: torch.optim.Optimizer,

               device: torch.device):
    """

    Melakukan satu epoch training.

    

    Mengatur model ke mode training, melakukan forward pass,

    menghitung loss, melakukan backpropagation, dan update weights.

    """
    # 1. Set model ke mode training
    # Ini penting untuk mengaktifkan lapisan seperti Dropout dan BatchNorm
    model.train()
    
    # 2. Setup variabel pelacak loss dan akurasi
    train_loss, train_acc = 0, 0
    
    # 3. Loop melalui data loader
    # Gunakan tqdm untuk progress bar
    for X, y in tqdm(dataloader, desc="Training"):
        # Pindahkan data ke device (GPU jika ada)
        X, y = X.to(device), y.to(device)
        
        # 4. Forward pass
        y_pred_logits = model(X)
        
        # 5. Hitung loss
        loss = loss_fn(y_pred_logits, y)
        train_loss += loss.item() 
        
        # 6. Nol-kan gradien optimizer
        optimizer.zero_grad()
        
        # 7. Backpropagation
        loss.backward()
        
        # 8. Update weights
        optimizer.step()
        
        # 9. Hitung akurasi
        # Ambil kelas dengan probabilitas tertinggi
        y_pred_class = torch.argmax(y_pred_logits, dim=1)
        train_acc += (y_pred_class == y).sum().item() / len(y_pred_logits)
        
    # 10. Hitung rata-rata loss dan akurasi per epoch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    
    return train_loss, train_acc

def val_step(model: torch.nn.Module, 

             dataloader: torch.utils.data.DataLoader, 

             loss_fn: torch.nn.Module,

             device: torch.device):
    """

    Melakukan satu epoch validasi.

    

    Mengatur model ke mode evaluasi, melakukan forward pass,

    dan menghitung loss/akurasi. Tidak ada backpropagation.

    """
    # 1. Set model ke mode evaluasi
    # Ini penting untuk menonaktifkan Dropout dan BatchNorm
    model.eval() 
    
    # 2. Setup variabel pelacak loss dan akurasi
    val_loss, val_acc = 0, 0
    
    # 3. Matikan perhitungan gradien
    # Ini menghemat memori dan komputasi
    with torch.no_grad():
        # 4. Loop melalui data loader
        for X, y in tqdm(dataloader, desc="Validasi"):
            # Pindahkan data ke device
            X, y = X.to(device), y.to(device)
            
            # 5. Forward pass
            y_pred_logits = model(X)
            
            # 6. Hitung loss
            loss = loss_fn(y_pred_logits, y)
            val_loss += loss.item()
            
            # 7. Hitung akurasi
            y_pred_class = torch.argmax(y_pred_logits, dim=1)
            val_acc += (y_pred_class == y).sum().item() / len(y_pred_logits)
            
    # 8. Hitung rata-rata loss dan akurasi per epoch
    val_loss = val_loss / len(dataloader)
    val_acc = val_acc / len(dataloader)
    
    return val_loss, val_acc