# File size: 7,416 Bytes
# 04fefb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
Unified training script for deterministic segmentation baselines.
Uses official libraries: smp (U-Net, U-Net++), MONAI (Attention U-Net), smp MAnet (TransUNet-style).
"""
import argparse
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dataset import LIDCFlatDataset
from models import get_model


class DiceBCELoss(nn.Module):
    """Weighted sum of a soft-Dice loss and BCE-with-logits for binary segmentation.

    The Dice term is computed from sigmoid probabilities separately for every
    (sample, channel) pair, reducing over all dimensions after the first two,
    and the resulting scores are averaged. Reducing over *all* trailing
    dimensions lets the same loss handle ``(N, C, L)``, ``(N, C, H, W)`` and
    ``(N, C, D, H, W)`` inputs; for 4-D input this is the ``(2, 3)`` reduction.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        bce_weight: Multiplier applied to the BCE term.
        smooth: Constant added to the Dice numerator and denominator so the
            ratio stays defined when prediction and target are both empty.
    """
    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1e-5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """Return the scalar combined loss.

        Args:
            logits: Raw (pre-sigmoid) model outputs with at least 3 dimensions.
            targets: Tensor passed alongside ``logits`` to both loss terms.

        Raises:
            ValueError: If ``logits`` has fewer than 3 dimensions, i.e. there
                is no spatial axis to reduce over.
        """
        if logits.dim() < 3:
            raise ValueError(
                f"DiceBCELoss expects at least 3 dimensions, got shape {tuple(logits.shape)}"
            )

        # BCE term, computed directly on the logits.
        bce_loss = self.bce(logits, targets)

        # Dice term: reduce over every axis after (batch, channel).
        spatial_dims = tuple(range(2, logits.dim()))
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum(dim=spatial_dims)
        cardinality = probs.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims)
        dice = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1 - dice.mean()

        return self.dice_weight * dice_loss + self.bce_weight * bce_loss


def compute_dice(pred, target):
    """Return the hard Dice score of thresholded predictions against ``target``.

    ``pred`` is passed through a sigmoid and binarised at 0.5; the sums run
    over every element of the tensors, so a single scalar tensor is returned.
    """
    eps = 1e-5
    binary = torch.sigmoid(pred).gt(0.5).float()
    overlap = torch.sum(binary * target)
    denominator = binary.sum() + target.sum() + eps
    return (2. * overlap + eps) / denominator


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Every batch is unpacked as a 3-tuple whose third element is ignored.
    Returns ``(mean_loss, mean_dice)``, each averaged over the number of
    batches reported by ``len(loader)``.
    """
    model.train()
    running_loss = 0
    running_dice = 0

    for inputs, targets, _unused in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Accumulate plain Python floats so no graph is kept alive.
        running_loss += batch_loss.item()
        batch_dice = compute_dice(logits, targets)
        running_dice += batch_dice.item()

    mean_loss = running_loss / len(loader)
    mean_dice = running_dice / len(loader)
    return mean_loss, mean_dice


@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Every batch is unpacked as a 3-tuple whose third element is ignored.
    Returns ``(mean_loss, mean_dice)``, each averaged over the number of
    batches reported by ``len(loader)``.
    """
    model.eval()
    running_loss = 0
    running_dice = 0

    for inputs, targets, _unused in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        running_loss += batch_loss.item()
        batch_dice = compute_dice(logits, targets)
        running_dice += batch_dice.item()

    mean_loss = running_loss / len(loader)
    mean_dice = running_dice / len(loader)
    return mean_loss, mean_dice


def main():
    """Parse CLI arguments, train the selected model and checkpoint the best epoch.

    The dataset at ``--data_dir`` is split into train/validation subsets with a
    fixed seed (42). The validation subset reuses the same indices on a second
    dataset instance created with ``augment=False``. The model is trained with
    ``DiceBCELoss``, Adam and cosine annealing; whenever validation Dice
    improves, a checkpoint is written to ``<checkpoint_dir>/<model>_best.pth``.
    Training stops early after ``--patience`` epochs without improvement.

    Raises:
        SystemExit: Via ``argparse`` when ``--val_split`` is not strictly
            between 0 and 1.
        ValueError: When the dataset is too small for the requested split to
            leave at least one sample in both partitions.
    """
    parser = argparse.ArgumentParser(description="Train deterministic segmentation baseline")
    parser.add_argument("--model", type=str, required=True,
                        choices=["unet", "attention_unet", "unetpp", "transunet", "nnunet"],
                        help="Model architecture")
    parser.add_argument("--data_dir", type=str, default="data/flat_train",
                        help="Path to flat training data")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--val_split", type=float, default=0.1, help="Validation split ratio")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Checkpoint directory")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
    args = parser.parse_args()

    # A ratio outside (0, 1) leaves one partition empty, which would only
    # surface later as a division by zero or a sampler error.
    if not 0.0 < args.val_split < 1.0:
        parser.error("--val_split must be strictly between 0 and 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Create dataset
    full_dataset = LIDCFlatDataset(args.data_dir, augment=True)

    # Split into train/val
    val_size = int(len(full_dataset) * args.val_split)
    train_size = len(full_dataset) - val_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset of {len(full_dataset)} samples with val_split={args.val_split} "
            f"gives train={train_size}, val={val_size}; both must be non-empty"
        )

    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

    # Disable augmentation for validation: same indices, non-augmenting dataset instance.
    val_dataset_no_aug = LIDCFlatDataset(args.data_dir, augment=False)
    val_indices = val_dataset.indices
    val_dataset_final = torch.utils.data.Subset(val_dataset_no_aug, val_indices)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset_final, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.num_workers, pin_memory=pin_memory)

    print(f"Train: {train_size}, Val: {val_size}")

    # Create model
    model = get_model(args.model, in_channels=1, num_classes=1)
    model = model.to(device)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: {args.model}, Parameters: {n_params:,}")

    # Loss, optimizer, scheduler
    criterion = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    best_val_dice = 0
    best_epoch = None  # stays None until a checkpoint has actually been written
    patience_counter = 0

    checkpoint_path = os.path.join(args.checkpoint_dir, f"{args.model}_best.pth")

    print(f"\nTraining {args.model} for {args.epochs} epochs...")
    print("-" * 70)

    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        train_loss, train_dice = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_dice = validate(model, val_loader, criterion, device)
        scheduler.step()

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            best_epoch = epoch
            patience_counter = 0
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_dice": val_dice,
                "model_name": args.model,
            }, checkpoint_path)
        else:
            patience_counter += 1

        # Print progress every 5 epochs or at start/end
        if epoch % 5 == 0 or epoch == 1 or epoch == args.epochs:
            elapsed = time.time() - start_time
            lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch:3d}/{args.epochs} | "
                  f"Train Loss: {train_loss:.4f} Dice: {train_dice:.4f} | "
                  f"Val Loss: {val_loss:.4f} Dice: {val_dice:.4f} | "
                  f"Best: {best_val_dice:.4f} | "
                  f"LR: {lr:.2e} | "
                  f"Time: {elapsed:.0f}s")

        # Early stopping
        if patience_counter >= args.patience:
            print(f"\nEarly stopping at epoch {epoch} (patience={args.patience})")
            break

    total_time = time.time() - start_time
    print("-" * 70)
    print(f"Training complete! Best val Dice: {best_val_dice:.4f}")
    print(f"Total time: {total_time:.0f}s ({total_time/60:.1f}min)")
    # Only claim a checkpoint exists if this run actually wrote one.
    if best_epoch is None:
        print("No checkpoint saved: validation Dice never exceeded 0")
    else:
        print(f"Checkpoint saved: {checkpoint_path}")


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()