"""Integration training pipeline: data prep, model selection, training,
early stopping, ClearML logging, and final test-set evaluation."""

import os
import subprocess
import sys
import time

import torch
import torch.nn as nn
from torch.optim import AdamW
from datasets import load_from_disk

# Import models
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN

# Import utils
from src.utils.config import load_config
from src.utils.metrics import accuracy, topk_accuracy
from src.train.early_stopping import EarlyStopping

# Import Dataloader
from src.DataLoader.dataloader import create_dataloader


def train_one_epoch(model, loader, criterion, opt, device):
    """Run a single training epoch.

    Args:
        model: the network to train (moved to ``device`` by the caller).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function (expects class-index targets).
        opt: optimizer stepping the trainable parameters.
        device: target device string/object for tensors.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over every sample seen.
    """
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        # Loader might return one-hot labels. CrossEntropyLoss needs indices.
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()

        # Weight by batch size so the epoch mean is exact even when the
        # final batch is smaller than the rest.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples


@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function (expects class-index targets).
        device: target device for tensors.
        topk: ``k`` for the top-k accuracy metric (clamped to the number
            of classes so small class counts cannot crash ``topk``).

    Returns:
        Tuple ``(mean_loss, top1_accuracy, topk_accuracy)``.
    """
    model.eval()
    total_loss, total_correct, total_topk, total_samples = 0.0, 0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()

        logits = model(inputs)
        loss = criterion(logits, labels)

        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        # Top-k — FIX: clamp k to the class count; logits.topk raises a
        # RuntimeError when k exceeds the size of dim 1.
        k = min(topk, logits.size(1))
        topk_preds = logits.topk(k, dim=1).indices
        total_topk += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
        total_samples += batch_size
    return (
        total_loss / total_samples,
        total_correct / total_samples,
        total_topk / total_samples,
    )


def main():
    """Run the end-to-end pipeline: config, data, model, train, test."""
    print("[INFO] Starting Integration Training Pipeline")

    # 1. Config
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)

    # 2. ClearML (optional — pipeline still runs without it installed)
    try:
        from clearml import Task
        task = Task.init(
            project_name=cfg.get("project", "PlantDisease"),
            task_name=cfg.get("task_name", "model_training"),
        )
        task.set_packages("./requirements.txt")
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        logger = None
        print("[INFO] ClearML not found, skipping logging")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")

    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            # FIX: use sys.exit — the builtin exit() is injected by the
            # `site` module and is not guaranteed in all interpreters.
            sys.exit(1)
        # FIX: verify the fallback script actually produced the dataset
        # before handing the path to load_from_disk.
        if not os.path.exists(data_path):
            print(f"[FATAL] Data path '{data_path}' still missing after processing.")
            sys.exit(1)

    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])
    dl_train = create_dataloader(
        ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(
        ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    dl_test = create_dataloader(
        ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)

    # 4. Model Selection & Optimizer Setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")
    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # For ResNet transfer learning, we typically only optimize the head
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'],
                    weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'],
                         p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # For custom CNN, we optimize all parameters
        opt = AdamW(model.parameters(), lr=cfg['lr'],
                    weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
    else:
        raise ValueError(
            f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")

    # 5. Setup Loss & Stopper
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])

    # 6. Loop
    ckpt_path = "checkpoints/best_baseline.pt"
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")
        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), ckpt_path)
        if stopper.step(val_acc):
            print("Early stopping.")
            break

    # 7. Final test-set evaluation — FIX: dl_test was built but never used.
    # Reload the best checkpoint (if one was saved) so the reported test
    # metrics reflect the best validation model, not the last epoch.
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    test_loss, test_acc, test_top5 = evaluate(model, dl_test, crit, device, topk=5)
    print(f"[RESULT] Test Loss: {test_loss:.4f} Acc: {test_acc:.3f} Top5: {test_top5:.3f}")
    if logger:
        logger.report_scalar("Loss", "test", test_loss, iteration=0)
        logger.report_scalar("Accuracy", "test", test_acc, iteration=0)

    if logger and os.path.exists(ckpt_path):
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object=ckpt_path)
        print("[SUCCESS] Model uploaded.")


if __name__ == "__main__":
    main()