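"""Integration training pipeline for plant-disease classification.

Loads the processed dataset, builds either a fine-tuned ResNet-18 or a custom
CNN, trains with early stopping, and (when available) logs metrics and the
best checkpoint to ClearML.
"""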
import os
import subprocess
import sys

import torch
import torch.nn as nn
from torch.optim import AdamW
from datasets import load_from_disk

# Import models
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN

# Import utils
from src.utils.config import load_config
from src.train.early_stopping import EarlyStopping

# Import Dataloader
from src.DataLoader.dataloader import create_dataloader

def train_one_epoch(model, loader, criterion, opt, device):
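    """Train for one epoch; returns (mean per-sample loss, top-1 accuracy)."""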
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        
        # Loader might return one-hot labels. CrossEntropyLoss needs indices.
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()

        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples

@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
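    """Evaluate without gradients; returns (mean loss, top-1 accuracy, top-k accuracy)."""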
    model.eval()
    total_loss, total_correct, total_topk, total_samples = 0.0, 0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()

        logits = model(inputs)
        loss = criterion(logits, labels)

        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        
        # Top-k: clamp k so logits.topk() cannot fail when the model has
        # fewer than `topk` classes
        k = min(topk, logits.size(1))
        topk_preds = logits.topk(k, dim=1).indices
        total_topk += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
        total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples, total_topk / total_samples

def main():
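    """End-to-end entry point: config, experiment tracking, data, model, training loop."""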
    print("[INFO] Starting Integration Training Pipeline")
    
    # 1. Config
    cfg = load_config()
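    # Keys read from the config below (assumed schema; the authoritative list
    # lives in src/utils/config.py): data_path, batch_size, num_classes, lr,
    # weight_decay, epochs, patience, min_delta, *_samples_per_epoch, plus the
    # optional project, task_name, model_type, and dropout.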
    os.makedirs("checkpoints", exist_ok=True)

    # 2. ClearML
    try:
        from clearml import Task
        task = Task.init(project_name=cfg.get("project", "PlantDisease"), task_name=cfg.get("task_name", "model_training"))
        task.set_packages("./requirements.txt")
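        # execute_remotely() enqueues this task and, when run locally with the
        # default exit_process=True, stops the local process so a ClearML agent
        # listening on the "default" queue picks it up.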
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        task, logger = None, None
        print("[INFO] ClearML not found, skipping logging")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")

    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            sys.exit(1)

    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])
    
    # Final positional argument appears to be a shuffle flag: only the
    # training split is shuffled.
    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)

    # 4. Model Selection & Optimizer Setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")

    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # For ResNet transfer learning, we typically only optimize the head
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
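        # Backbone weights receive no optimizer updates here; whether they are
        # also frozen (requires_grad=False) is up to make_resnet18.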
        print("[INFO] Optimizer configured for ResNet head only.")
    
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # For custom CNN, we optimize all parameters
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
        
    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")

    # 5. Setup Loss & Stopper
    crit = nn.CrossEntropyLoss()
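    # Assumes EarlyStopping treats higher scores as better (it receives val
    # accuracy) and that step() returns True after `patience` epochs without
    # an improvement of at least `min_delta`.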
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])
    
    # 6. Loop
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)
        
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")
        
        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)
        
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "checkpoints/best_baseline.pt")
            
        if stopper.step(val_acc):
            print("Early stopping.")
            break
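
    # Evaluate the best checkpoint on the held-out test split (dl_test is
    # otherwise created but never used)
    model.load_state_dict(torch.load("checkpoints/best_baseline.pt", map_location=device))
    test_loss, test_acc, test_top5 = evaluate(model, dl_test, crit, device, topk=5)
    print(f"[RESULT] Test Loss: {test_loss:.4f} Acc: {test_acc:.3f} Top5: {test_top5:.3f}")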
    
    if logger:
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object="checkpoints/best_baseline.pt")
        print("[SUCCESS] Model uploaded.")

if __name__ == "__main__":
    main()