"""Training loop and evaluation plots for an image classifier."""

from src.model import Classifier
from src.dataloader import ImageDataset, collate_fn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
import torch.nn as nn
import time
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    roc_curve,
    auc
)
from sklearn.preprocessing import label_binarize


def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    Intended to be passed as ``worker_init_fn``; the seed is derived from
    ``torch.initial_seed()`` folded into 32 bits. ``worker_id`` is accepted
    for the DataLoader callback signature but is not used.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(y_true, y_pred, y_prob)`` arrays."""
    true_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            # A 4-D batch whose last axis has 1 or 3 entries is treated as
            # channel-last and moved to channel-first before the forward pass.
            if images.ndim == 4 and images.shape[-1] in (1, 3):
                images = images.permute(0, 3, 1, 2)
            logits = model(images.to(device))
            probs = torch.softmax(logits, dim=1)
            pred_chunks.append(torch.argmax(probs, dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())
            true_chunks.append(labels.cpu().numpy())

    if not true_chunks:
        # np.concatenate([]) would raise an unhelpful ValueError; be explicit.
        raise ValueError("model_evaluation requires a non-empty validation set")

    return (
        np.concatenate(true_chunks),
        np.concatenate(pred_chunks),
        np.concatenate(prob_chunks),
    )


def _plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw the confusion matrix with one row/column per entry of ``class_names``."""
    num_classes = len(class_names)
    # Passing ``labels`` pins the matrix to (num_classes, num_classes) even
    # when some class never occurs in y_true / y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cm)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(class_names, rotation=75)
    ax.set_yticklabels(class_names)
    for (row, col), count in np.ndenumerate(cm):
        ax.text(col, row, count, ha="center", va="center")
    fig.tight_layout()
    return fig


def _plot_classification_report(y_true, y_pred, class_names):
    """Render scikit-learn's classification report as a table figure."""
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
    )

    def _avg_row(title, stats):
        support = f"{int(stats['support'])}" if 'support' in stats else ""
        return [
            title,
            f"{stats['precision']:.3f}",
            f"{stats['recall']:.3f}",
            f"{stats['f1-score']:.3f}",
            support,
        ]

    table_data = [
        [
            cls,
            f"{report[cls]['precision']:.3f}",
            f"{report[cls]['recall']:.3f}",
            f"{report[cls]['f1-score']:.3f}",
            int(report[cls]['support']),
        ]
        for cls in class_names
    ]

    # Depending on the scikit-learn version, passing ``labels`` may cause the
    # report to expose "micro avg" instead of "accuracy"; compute it directly
    # in that case so the table is always complete.
    accuracy = report.get("accuracy")
    if accuracy is None:
        accuracy = float(np.mean(y_true == y_pred))

    table_data.append(["Accuracy", f"{accuracy:.3f}", "", "", ""])
    table_data.append(_avg_row("Macro Avg", report["macro avg"]))
    table_data.append(_avg_row("Weighted Avg", report["weighted avg"]))

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis("off")
    table = ax.table(
        cellText=table_data,
        colLabels=["Class", "Precision", "Recall", "F1", "Support"],
        loc="center",
    )
    table.scale(1, 2)
    ax.set_title("Classification Report")
    return fig


def _plot_roc_curves(y_true, y_prob, class_names):
    """Plot one-vs-rest ROC curves, one per column of ``y_prob``."""
    fig, ax = plt.subplots(figsize=(6, 6))
    for idx, name in enumerate(class_names):
        # Build the one-vs-rest target directly: label_binarize collapses to a
        # single column for two classes, which breaks per-class indexing.
        is_class = (y_true == idx).astype(int)
        fpr, tpr, _ = roc_curve(is_class, y_prob[:, idx])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{name} (AUC={roc_auc:.3f})")

    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC-AUC Curve")
    ax.legend()
    ax.grid(True)
    return fig


def model_evaluation(model, val_set, device, batch_size=32, num_workers=0, class_names=None):
    """Evaluate ``model`` on ``val_set`` and build three diagnostic figures.

    Args:
        model: Module called as ``model(images)`` returning per-class logits.
        val_set: Dataset yielding ``(images, labels)`` pairs that the default
            DataLoader collation can batch into tensors.
            NOTE(review): ``ModelTrainer`` uses the project ``collate_fn`` but
            this function does not -- confirm that is intentional.
        device: Device the image batches are moved to before the forward pass.
        batch_size: Evaluation batch size.
        num_workers: Number of DataLoader worker processes.
        class_names: Display names, one per model output; defaults to
            ``"Class 0"``, ``"Class 1"``, ...

    Returns:
        ``(cm_fig, cr_fig, roc_fig)``: confusion-matrix, classification-report
        and ROC matplotlib figures.

    Raises:
        ValueError: If ``val_set`` yields no batches, or ``class_names`` does
            not have one entry per model output.
    """
    model.eval()
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    y_true, y_pred, y_prob = _collect_predictions(model, val_loader, device)

    num_classes = y_prob.shape[1]
    if class_names is None:
        class_names = [f"Class {i}" for i in range(num_classes)]
    else:
        class_names = list(class_names)
    if len(class_names) != num_classes:
        raise ValueError(
            f"class_names has {len(class_names)} entries but the model "
            f"outputs {num_classes} classes"
        )

    cm_fig = _plot_confusion_matrix(y_true, y_pred, class_names)
    cr_fig = _plot_classification_report(y_true, y_pred, class_names)
    roc_fig = _plot_roc_curves(y_true, y_prob, class_names)
    return cm_fig, cr_fig, roc_fig


class ModelTrainer:
    """Train a ``Classifier`` with Adam and cross-entropy, tracking the best weights.

    Attributes set in ``__init__``:
        train_loader / val_loader: DataLoaders built with the project
            ``collate_fn`` (``val_loader`` is ``None`` without a ``val_set``).
        class_names: Taken from ``model.classes``.
        best_model_state / best_val_acc: Snapshot of the weights with the
            highest validation accuracy seen by ``train``.
        interrupt: Flag polled inside the batch loops; setting it to ``True``
            (e.g. from another thread or a UI callback) stops training.
    """

    def __init__(self, model: Classifier, train_set: ImageDataset, val_set: ImageDataset = None,
                 batch_size=32, lr=1e-3, device='cpu', return_fig=False, seed=None):
        """Build the data loaders, optimizer and loss.

        Args:
            model: Classifier to optimise; must expose ``classes`` and
                ``parameters()``.
            train_set: Training dataset (shuffled each epoch).
            val_set: Optional validation dataset.
            batch_size: Batch size for both loaders.
            lr: Adam learning rate (weight decay is fixed at ``1e-4``).
            device: Device the label tensors are moved to.
            return_fig: When true, ``visualize_batch`` returns its figures
                instead of showing and closing them.
            seed: Optional seed for the shuffling generator.
        """
        # NOTE(review): when ``seed`` is None the generator is still passed to
        # the DataLoader unseeded; an unseeded torch.Generator may start from
        # PyTorch's fixed default seed, giving the same shuffle order on every
        # run -- confirm this is intended.
        shuffle_rng = torch.Generator()
        if seed is not None:
            shuffle_rng.manual_seed(seed)

        self.train_loader = DataLoader(
            train_set,
            batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=shuffle_rng
        )
        self.device = device

        if val_set is not None:
            self.val_loader = DataLoader(
                val_set,
                batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker
            )
        else:
            self.val_loader = None

        self.class_names = model.classes
        self.model = model
        self.lr = lr
        self.optim = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        self.optim.zero_grad()
        self.criterion = nn.CrossEntropyLoss()
        self.return_fig = return_fig
        self.best_model_state = None
        self.best_val_acc = 0.0
        self.interrupt = False

    def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
        """Plot random samples of a batch plus the model's feature visualisation.

        Args:
            imgs: Either a list of channel-last NumPy images (stacked and
                converted to a channel-first float tensor here) or a
                channel-first tensor batch.
            preds: Tensor of predicted class indices for the batch.
            labels: Tensor of ground-truth class indices for the batch.
            class_names: Optional index-to-name lookup used in the titles.
            max_samples: Maximum number of images drawn.

        Returns:
            ``[prediction_figure, *feature_figures]`` when ``self.return_fig``
            is true; otherwise the figures are shown, closed, and ``None`` is
            returned.
        """
        # Keep the un-stacked input so the sample handed to
        # ``model.visualize_feature`` is in the caller's original format.
        original_imgs = imgs
        if isinstance(imgs, list):
            imgs = np.stack(imgs, axis=0)
            imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2).float()

        imgs_np = imgs.cpu().numpy()
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()

        batch_size = imgs_np.shape[0]
        indices = random.sample(range(batch_size), min(max_samples, batch_size))
        first_image = original_imgs[indices[0]]

        fig_pred = plt.figure(figsize=(6 * len(indices), 5))
        grid = fig_pred.add_gridspec(1, len(indices))
        for col, idx in enumerate(indices):
            ax = fig_pred.add_subplot(grid[0, col])
            ax.imshow(imgs_np[idx].transpose(1, 2, 0))
            if class_names:
                title = f"P: {class_names[preds[idx]]} | T: {class_names[labels[idx]]}"
            else:
                title = f"P: {preds[idx]} | T: {labels[idx]}"
            ax.set_title(title)
            ax.axis("off")
        fig_pred.tight_layout()

        raw_features = self.model.visualize_feature(first_image, show=False)
        feature_figs = []
        for feature in raw_features:
            if isinstance(feature, plt.Figure):
                feature_figs.append(feature)
                continue
            # Objects with a ``mode`` attribute (presumably PIL images) are
            # converted to arrays before plotting.
            if hasattr(feature, "mode"):
                feature = np.array(feature)
            height, width = feature.shape[:2]
            dpi = 100
            fig = plt.figure(
                figsize=(max(4, width / dpi), max(4, height / dpi)),
                dpi=dpi,
            )
            ax = fig.add_subplot(111)
            ax.imshow(feature)
            ax.axis("off")
            feature_figs.append(fig)

        all_figs = [fig_pred] + feature_figs
        if self.return_fig:
            # The caller takes ownership of (and must close) the figures.
            return all_figs

        plt.show()
        # Close every figure created here, not only the prediction grid, so
        # repeated calls do not accumulate open figures in pyplot.
        for fig in all_figs:
            plt.close(fig)
        return None

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Stops early when ``self.interrupt`` is set. ``epoch`` is accepted for
        interface symmetry and is not used.

        Returns:
            ``(avg_loss, avg_acc)``: mean per-batch loss and sample accuracy
            over the batches actually processed (``0.0`` each when none were).
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        batches_done = 0

        train_pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for imgs, labels in train_pbar:
            if self.interrupt:
                break
            labels = labels.to(self.device)

            # Forward
            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)

            # Backward
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            total_loss += loss.item()
            batches_done += 1
            train_pbar.set_postfix(acc=correct / total, loss=loss.item())
        # Breaking out of the loop leaves the bar open; close it explicitly.
        train_pbar.close()

        # Guard against an interrupt before the first batch or an empty loader.
        avg_loss = total_loss / batches_done if batches_done else 0.0
        avg_acc = correct / total if total else 0.0
        return avg_loss, avg_acc

    def train(self, epochs=10, visualize_every=5):
        """Generator that trains for ``epochs`` epochs, yielding after each one.

        Args:
            epochs: Number of epochs to run.
            visualize_every: Visualise a validation batch on epoch 1 and on
                every multiple of this value; ``0``/``None`` disables the
                periodic visualisation (epoch 1 is still visualised).

        Yields:
            Per epoch: ``(train_loss, train_acc, val_loss, val_acc, fig)``,
            with the last three ``None`` when there is no validation loader.
            After the final epoch the best weights (if any) are loaded into
            ``self.model`` and one more tuple is yielded:
            ``(train_losses, train_accuracies, val_losses, val_accuracies, None)``.

        If ``self.interrupt`` is set the generator ends immediately, without
        the final history tuple and without restoring the best weights.
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self.train_one_epoch(epoch)
            if self.interrupt:
                return
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)

            if self.val_loader is None:
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
                yield train_loss, train_acc, None, None, None
                continue

            # ``visualize_every`` is checked for truthiness first so that 0 or
            # None does not trigger a modulo-by-zero / TypeError.
            show = epoch == 1 or bool(visualize_every and epoch % visualize_every == 0)
            val_loss, val_acc, fig = self.validate(epoch, visualize=show)
            if self.interrupt:
                return
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")

            if val_acc > self.best_val_acc:
                print(f"New best model found at epoch {epoch} (Val Acc: {val_acc:.4f})")
                self.best_val_acc = val_acc
                # Clone so later optimiser steps do not mutate the snapshot.
                self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            yield train_loss, train_acc, val_loss, val_acc, fig

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"Best model (Val Acc: {self.best_val_acc:.4f}) loaded into trainer.model")
        yield train_losses, train_accuracies, val_losses, val_accuracies, None

    def validate(self, epoch, visualize=False):
        """Evaluate the model on ``self.val_loader``.

        Args:
            epoch: Accepted for interface symmetry; not used.
            visualize: When true, the first validation batch is passed to
                ``visualize_batch``.

        Returns:
            ``None`` when there is no validation loader; otherwise
            ``(avg_loss, acc, fig)`` where ``fig`` is whatever
            ``visualize_batch`` returned (or ``None``). Loss and accuracy are
            averaged over the batches actually processed and are ``0.0`` when
            none were (interrupt before the first batch, or an empty loader).
        """
        if self.val_loader is None:
            return

        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        batches_done = 0
        val_imgs_display = None
        val_preds_display = None
        val_labels_display = None
        fig = None

        val_pbar = tqdm(self.val_loader, desc="Validation", leave=False)
        with torch.no_grad():
            for imgs, labels in val_pbar:
                if self.interrupt:
                    break
                labels = labels.to(self.device)
                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()

                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                batches_done += 1

                # Remember only the first batch for visualisation.
                if visualize and val_imgs_display is None:
                    val_imgs_display = imgs
                    val_preds_display = preds
                    val_labels_display = labels
                val_pbar.set_postfix(loss=loss.item(), acc=correct / total)
        # Breaking out of the loop leaves the bar open; close it explicitly.
        val_pbar.close()

        # Guard against an interrupt before the first batch or an empty loader.
        avg_loss = total_loss / batches_done if batches_done else 0.0
        acc = correct / total if total else 0.0

        if visualize and val_imgs_display is not None:
            fig = self.visualize_batch(
                val_imgs_display,
                val_preds_display,
                val_labels_display,
                self.class_names,
            )
        return avg_loss, acc, fig