# NOTE: the original capture carried a Hugging Face "Spaces: Sleeping" status
# banner here — page residue, not source code; preserved as a comment so the
# module parses.
import os
import subprocess
import sys
import time

import torch
import torch.nn as nn
from datasets import load_from_disk
from torch.optim import AdamW

# Project-local imports: models, utils, dataloader
from src.DataLoader.dataloader import create_dataloader
from src.models.cnn_model import PlantCNN
from src.models.resnet18_finetune import make_resnet18
from src.train.early_stopping import EarlyStopping
from src.utils.config import load_config
from src.utils.metrics import accuracy, topk_accuracy
def train_one_epoch(model, loader, criterion, opt, device):
    """Train `model` for one pass over `loader`.

    Returns:
        (mean_loss, accuracy): loss averaged per sample and top-1 accuracy
        over every sample seen this epoch.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        # The loader may yield one-hot targets; CrossEntropyLoss wants indices.
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()

        batch = inputs.size(0)
        # Weight the batch-mean loss by batch size so the epoch mean is per-sample.
        running_loss += loss.item() * batch
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += batch
    return running_loss / n_seen, n_correct / n_seen
def evaluate(model, loader, criterion, device, topk=5):
    """Evaluate `model` on `loader` without updating weights.

    Args:
        model: network producing class logits.
        loader: iterable of (inputs, labels) batches.
        criterion: loss taking (logits, index_labels).
        device: target device for the batch tensors.
        topk: k for top-k accuracy (clamped to the number of classes).

    Returns:
        (mean_loss, top1_accuracy, topk_accuracy), all averaged per sample.
    """
    model.eval()
    total_loss, total_correct, total_topk, total_samples = 0.0, 0, 0, 0
    # Fix: run under no_grad — the original built autograd graphs for every
    # eval batch, wasting memory and compute.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            # One-hot targets are collapsed to class indices for the criterion.
            if labels.ndim > 1:
                labels = labels.argmax(dim=1)
            labels = labels.to(device).long()
            logits = model(inputs)
            loss = criterion(logits, labels)
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (logits.argmax(1) == labels).sum().item()
            # Top-k: clamp k so models with fewer than `topk` classes don't crash.
            k = min(topk, logits.size(1))
            topk_preds = logits.topk(k, dim=1).indices
            total_topk += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
            total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples, total_topk / total_samples
def main():
    """End-to-end training pipeline: config -> data -> model -> train -> test.

    Side effects: creates ./checkpoints, may run process_dataset.py to build
    the dataset, saves the best checkpoint, and (when ClearML is installed)
    logs metrics and uploads the best model artifact.
    """
    print("[INFO] Starting Integration Training Pipeline")
    # 1. Config
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)
    # 2. ClearML (optional experiment tracking)
    try:
        from clearml import Task
        task = Task.init(project_name=cfg.get("project", "PlantDisease"), task_name=cfg.get("task_name", "model_training"))
        task.set_packages("./requirements.txt")
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        # Fix: bind `task` explicitly — the artifact upload at the end previously
        # relied on `logger` truthiness to imply `task` existed.
        task = None
        logger = None
        print("[INFO] ClearML not found, skipping logging")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")
    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            # Fix: builtin exit() is a `site` convenience and not guaranteed;
            # sys.exit is the portable way to terminate with a status code.
            sys.exit(1)
    # 3. Data
    print(f"[INFO] Loading data from {cfg['data_path']}")
    ds_dict = load_from_disk(cfg['data_path'])
    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)
    # 4. Model Selection & Optimizer Setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")
    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # Transfer learning: only the classification head is optimized.
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # Custom CNN trains from scratch: optimize all parameters.
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")
    # 5. Setup Loss & Stopper
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])
    # 6. Training loop: checkpoint on best val accuracy, stop early on plateau.
    best_acc = 0.0
    ckpt_path = "checkpoints/best_baseline.pt"
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}")
        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), ckpt_path)
        if stopper.step(val_acc):
            print("Early stopping.")
            break
    # 7. Final test evaluation (fix: dl_test was created but never used).
    # Reload the best checkpoint so the test score reflects the saved model.
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    test_loss, test_acc, test_top5 = evaluate(model, dl_test, crit, device, topk=5)
    print(f"[RESULT] Test Loss: {test_loss:.4f} Acc: {test_acc:.3f} Top5: {test_top5:.3f}")
    if logger:
        logger.report_scalar("Loss", "test", test_loss, iteration=0)
        logger.report_scalar("Accuracy", "test", test_acc, iteration=0)
    if task:
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object=ckpt_path)
        print("[SUCCESS] Model uploaded.")

if __name__ == "__main__":
    main()