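"""Integration training pipeline.

Trains either a fine-tuned ResNet-18 (head only) or a custom PlantCNN on the
processed plant-disease dataset, with optional ClearML tracking, early stopping,
and best-checkpoint saving. (Summary inferred from the code below.)
"""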
import os
import time
import torch
import torch.nn as nn
from torch.optim import AdamW
from datasets import load_from_disk
import subprocess
import sys
# Import models
from src.models.resnet18_finetune import make_resnet18
from src.models.cnn_model import PlantCNN
# Import utils
from src.utils.config import load_config
from src.utils.metrics import accuracy, topk_accuracy
from src.train.early_stopping import EarlyStopping
# Import Dataloader
from src.DataLoader.dataloader import create_dataloader
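# NOTE: load_config() is expected to return a dict-like config containing at least
# the keys used below: data_path, batch_size, train/val/test_samples_per_epoch,
# num_classes, lr, weight_decay, epochs, patience, min_delta, and optionally
# model_type ('resnet18' or 'cnn'), dropout, project, and task_name.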
def train_one_epoch(model, loader, criterion, opt, device):
    """Run a single training epoch and return (mean loss, top-1 accuracy)."""
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        # Loader might return one-hot labels. CrossEntropyLoss needs class indices.
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()
        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples
@torch.no_grad()
def evaluate(model, loader, criterion, device, topk=5):
    """Evaluate on a loader and return (mean loss, top-1 accuracy, top-k accuracy)."""
    model.eval()
    total_loss, total_correct, total_topk, total_samples = 0.0, 0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        labels = labels.to(device).long()
        logits = model(inputs)
        loss = criterion(logits, labels)
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        # Top-k: count a hit if the true label is among the k highest logits
        topk_preds = logits.topk(topk, dim=1).indices
        total_topk += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
        total_samples += batch_size
    return total_loss / total_samples, total_correct / total_samples, total_topk / total_samples
def main():
    print("[INFO] Starting Integration Training Pipeline")

    # 1. Config
    cfg = load_config()
    os.makedirs("checkpoints", exist_ok=True)

    # 2. ClearML
    try:
        from clearml import Task
        task = Task.init(
            project_name=cfg.get("project", "PlantDisease"),
            task_name=cfg.get("task_name", "model_training"),
        )
        task.set_packages("./requirements.txt")
        task.execute_remotely(queue_name="default")
        task.connect(cfg)
        logger = task.get_logger()
        print("[INFO] ClearML Initialized")
    except ImportError:
        logger = None
        print("[INFO] ClearML not found, skipping logging")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Device: {device}")

    data_path = cfg['data_path']
    if not os.path.exists(data_path):
        print(f"[WARN] Data path '{data_path}' not found.")
        print("[INFO] Attempting to run data processing script...")
        try:
            subprocess.check_call([sys.executable, "process_dataset.py"])
            print("[SUCCESS] Data processing complete.")
        except subprocess.CalledProcessError as e:
            print(f"[FATAL] Data processing failed: {e}")
            sys.exit(1)

    # 3. Data
    print(f"[INFO] Loading data from {data_path}")
    ds_dict = load_from_disk(data_path)
    dl_train = create_dataloader(ds_dict['train'], cfg['batch_size'], cfg['train_samples_per_epoch'], True)
    dl_val = create_dataloader(ds_dict['validation'], cfg['batch_size'], cfg['val_samples_per_epoch'], False)
    dl_test = create_dataloader(ds_dict['test'], cfg['batch_size'], cfg['test_samples_per_epoch'], False)
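    # Assumed create_dataloader signature, inferred from the calls above:
    # create_dataloader(dataset_split, batch_size, samples_per_epoch, shuffle).
    # The parameter names are guesses; see src/DataLoader/dataloader.py for the
    # actual definition. The final boolean enables shuffling for the train split only.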
    # 4. Model Selection & Optimizer Setup
    model_type = cfg.get('model_type', 'resnet18').lower()
    print(f"[INFO] Initializing model architecture: {model_type}")
    if model_type == 'resnet18':
        model = make_resnet18(num_classes=cfg['num_classes'])
        model = model.to(device)
        # For ResNet transfer learning, we typically only optimize the head
        opt = AdamW(model.fc.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for ResNet head only.")
    elif model_type == 'cnn':
        model = PlantCNN(num_classes=cfg['num_classes'], p_drop=cfg.get('dropout', 0.5))
        model = model.to(device)
        # For custom CNN, we optimize all parameters
        opt = AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
        print("[INFO] Optimizer configured for full CNN parameters.")
    else:
        raise ValueError(f"Unknown model_type in config: {model_type}. Must be 'resnet18' or 'cnn'.")
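    # Note: passing only model.fc.parameters() to AdamW means just the classification
    # head is updated; the backbone is assumed to be frozen (requires_grad=False)
    # inside make_resnet18. If it is not, backbone gradients are still computed on
    # backward() even though the optimizer never applies them.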
    # 5. Setup Loss & Stopper
    crit = nn.CrossEntropyLoss()
    stopper = EarlyStopping(patience=cfg['patience'], min_delta=cfg['min_delta'])
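    # Assumption: EarlyStopping.step(val_acc) returns True once validation accuracy
    # has failed to improve by at least min_delta for `patience` consecutive epochs
    # (see src/train/early_stopping.py for the actual behaviour).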
    # 6. Loop
    best_acc = 0.0
    for epoch in range(1, cfg['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(model, dl_train, crit, opt, device)
        val_loss, val_acc, val_top5 = evaluate(model, dl_val, crit, device, topk=5)
        print(
            f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} Top5: {val_top5:.3f}"
        )
        if logger:
            logger.report_scalar("Loss", "train", train_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "train", train_acc, iteration=epoch)
            logger.report_scalar("Loss", "val", val_loss, iteration=epoch)
            logger.report_scalar("Accuracy", "val", val_acc, iteration=epoch)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "checkpoints/best_baseline.pt")
        if stopper.step(val_acc):
            print("Early stopping.")
            break

    if logger:
        print("[INFO] Uploading best model artifact to ClearML...")
        task.upload_artifact(name="best_model", artifact_object="checkpoints/best_baseline.pt")
        print("[SUCCESS] Model uploaded.")
if __name__ == "__main__":
    main()