import os
import time
import torch
import torch.nn as nn
from torch.optim import AdamW
from datasets import load_from_disk
import subprocess
import sys
# Import models
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN
# Import utils
from src.utils.config import load_config
from src.utils.metrics import accuracy, topk_accuracy
from src.train.early_stopping import EarlyStopping
# Import Dataloader
from src.DataLoader.dataloader import create_dataloader
def train_one_epoch(model, loader, criterion, opt, device):
    """Run one full optimization pass over `loader`.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of (inputs, labels) batches.
        criterion: loss function taking (logits, class-index targets).
        opt: optimizer stepping the model's (registered) parameters.
        device: torch device string/object the batches are moved to.

    Returns:
        (mean_loss, accuracy) — both averaged over every sample seen.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for x, y in loader:
        x = x.to(device)
        # Collapse one-hot targets to class indices; CrossEntropyLoss wants indices.
        if y.ndim > 1:
            y = y.argmax(dim=1)
        y = y.to(device).long()
        opt.zero_grad(set_to_none=True)
        out = model(x)
        batch_loss = criterion(out, y)
        batch_loss.backward()
        opt.step()
        n = x.size(0)
        n_seen += n
        # Weight the batch-mean loss by batch size so the epoch mean is per-sample.
        running_loss += batch_loss.item() * n
        n_correct += (out.argmax(1) == y).sum().item()
    return running_loss / n_seen, n_correct / n_seen
@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
    """Score `model` on `loader` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of (inputs, labels) batches.
        criterion: loss function taking (logits, class-index targets).
        device: torch device the batches are moved to.
        topk: k for the top-k accuracy metric (default 5).

    Returns:
        (mean_loss, top-1 accuracy, top-`topk` accuracy), each averaged
        over every sample seen.
    """
    model.eval()
    loss_sum = 0.0
    hits1 = 0
    hitsk = 0
    seen = 0
    for x, y in loader:
        x = x.to(device)
        # One-hot targets are collapsed to class indices for the loss/metrics.
        if y.ndim > 1:
            y = y.argmax(dim=1)
        y = y.to(device).long()
        out = model(x)
        n = x.size(0)
        seen += n
        loss_sum += criterion(out, y).item() * n
        hits1 += (out.argmax(1) == y).sum().item()
        # A sample scores a top-k hit when its true class appears anywhere
        # among the k highest logits.
        best_k = out.topk(topk, dim=1).indices
        hitsk += (best_k == y.unsqueeze(1)).any(dim=1).sum().item()
    return loss_sum / seen, hits1 / seen, hitsk / seen
def main():
    """End-to-end training pipeline.

    Steps: load config, initialize optional ClearML tracking, ensure the
    processed dataset exists (running the preprocessing script if not),
    build dataloaders, select model + optimizer from config, train with
    early stopping while checkpointing the best validation accuracy, and
    finally upload the best checkpoint as a ClearML artifact.

    Raises:
        ValueError: if `cfg['model_type']` is neither 'resnet18' nor 'cnn'.
        SystemExit: if the on-demand data processing subprocess fails.
    """
    print("[INFO] Starting Integration Training Pipeline")
    # 1. Config
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)
    # 2. ClearML is optional: only an ImportError downgrades to local-only
    # logging. Any other Task.init failure intentionally propagates.
    try:
        from clearml import Task
        task = Task.init(project_name=cfg.get("project", "PlantDisease"), task_name=cfg.get("task_name", "model_training"))
        task.set_packages("./requirements.txt")
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        logger = None
        print("[INFO] ClearML not found, skipping logging")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")
    # If the processed dataset is missing, try to generate it on the fly.
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            # sys.exit, not the site-module exit(): the latter is meant for
            # interactive use and may be absent under some launch modes.
            sys.exit(1)
    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])
    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    # NOTE(review): dl_test is built but never evaluated below — confirm
    # whether a final held-out test pass was intended.
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)
    # 4. Model selection & optimizer setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")
    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # Transfer learning: only the classification head is optimized;
        # the pretrained backbone stays frozen from the optimizer's view.
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # Custom CNN trains from scratch: optimize every parameter.
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")
    # 5. Loss & early stopping
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])
    # 6. Train/validate loop; checkpoint whenever validation accuracy improves.
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")
        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "checkpoints/best_baseline.pt")
        if stopper.step(val_acc):
            print("Early stopping.")
            break
    # 7. Artifact upload (logger is only truthy when ClearML initialized,
    # so `task` is guaranteed to be bound here).
    if logger:
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object="checkpoints/best_baseline.pt")
        print("[SUCCESS] Model uploaded.")
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()