from src.model import Classifier
from src.dataloader import ImageDataset, collate_fn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
import torch.nn as nn
import time


def seed_worker(worker_id):
    """DataLoader worker_init_fn: reseed numpy/random from the torch base seed.

    Standard PyTorch recipe so that each worker process gets a distinct,
    reproducible seed derived from the DataLoader's generator.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class ModelTrainer:
    """Training/validation driver for a ``Classifier`` over ``ImageDataset``s.

    Owns the train/val DataLoaders, an Adam optimizer and a cross-entropy
    criterion, and optionally renders prediction/feature-map figures during
    validation.
    """

    def __init__(self, model: Classifier, train_set: ImageDataset,
                 val_set: ImageDataset = None, batch_size=32, lr=1e-3,
                 device='cpu', return_fig=False, seed=None):
        """
        Args:
            model: classifier exposing ``classes`` and ``visualize_feature``.
            train_set: training dataset.
            val_set: optional validation dataset; when None, validation and
                the per-epoch val metrics are skipped.
            batch_size: batch size for both loaders.
            lr: Adam learning rate (weight decay fixed at 1e-4).
            device: device the *labels* are moved to. NOTE(review): images are
                passed to the model as produced by ``collate_fn`` — presumably
                the model handles device placement itself; verify.
            return_fig: if True, ``visualize_batch`` returns figures instead
                of showing them.
            seed: optional seed for deterministic shuffling/worker seeding.
        """
        g = torch.Generator()
        if seed is not None:
            g.manual_seed(seed)
        self.train_loader = DataLoader(
            train_set,
            batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=g
        )
        self.device = device
        if val_set is not None:
            # Pass the same generator as the train loader so validation
            # worker seeds are also tied to the user-supplied seed
            # (previously omitted — reproducibility inconsistency).
            self.val_loader = DataLoader(
                val_set,
                batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker,
                generator=g
            )
        else:
            self.val_loader = None
        self.class_names = model.classes
        self.model = model
        self.lr = lr
        self.optim = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        self.optim.zero_grad()
        self.criterion = nn.CrossEntropyLoss()
        self.return_fig = return_fig

    def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
        """Render up to ``max_samples`` predictions plus model feature maps.

        Args:
            imgs: batch images — either a list of HWC arrays (as produced by
                ``collate_fn`` — TODO confirm) or an NCHW tensor.
            preds: predicted class indices (tensor).
            labels: ground-truth class indices (tensor).
            class_names: optional index -> name mapping for titles.
            max_samples: max number of images shown in the prediction figure.

        Returns:
            List of matplotlib Figures when ``self.return_fig`` is True,
            otherwise None (figures are shown and then closed).
        """
        # Keep a handle on the raw input so feature visualization receives
        # the image in its original (un-permuted) form.
        first_image = imgs
        if isinstance(imgs, list):
            imgs = np.stack(imgs, axis=0)
            imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2).float()
        imgs_np = imgs.cpu().numpy()
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
        batch_size = imgs_np.shape[0]
        indices = random.sample(range(batch_size), min(max_samples, batch_size))
        first_image = first_image[indices[0]]

        # --- prediction grid -------------------------------------------------
        fig_pred = plt.figure(figsize=(6 * len(indices), 5))
        grid = fig_pred.add_gridspec(1, len(indices))
        for col, idx in enumerate(indices):
            ax = fig_pred.add_subplot(grid[0, col])
            ax.imshow(imgs_np[idx].transpose(1, 2, 0))
            if class_names:
                title = f"P: {class_names[preds[idx]]} | T: {class_names[labels[idx]]}"
            else:
                title = f"P: {preds[idx]} | T: {labels[idx]}"
            ax.set_title(title)
            ax.axis("off")
        fig_pred.tight_layout()

        # --- feature maps ----------------------------------------------------
        raw_features = self.model.visualize_feature(first_image, show=False)
        feature_figs = []
        for f in raw_features:
            if isinstance(f, plt.Figure):
                feature_figs.append(f)
                continue
            # PIL images expose .mode; convert them to arrays for imshow.
            if hasattr(f, "mode"):
                f = np.array(f)
            h, w = f.shape[:2]
            dpi = 100
            fig_w = max(4, w / dpi)
            fig_h = max(4, h / dpi)
            fig = plt.figure(figsize=(fig_w, fig_h), dpi=dpi)
            ax = fig.add_subplot(111)
            ax.imshow(f)
            ax.axis("off")
            feature_figs.append(fig)

        all_figs = [fig_pred] + feature_figs
        if self.return_fig:
            return all_figs
        # BUGFIX: previously only fig_pred was closed, leaking every feature
        # figure each epoch (matplotlib accumulates open figures).
        plt.show()
        for f in all_figs:
            plt.close(f)
        return None

    def train_one_epoch(self):
        """Run one training pass over the train loader.

        Returns:
            (avg_loss, avg_acc): mean batch loss and overall accuracy.
        """
        self.model.train()
        total_loss = 0
        train_pbar = tqdm(self.train_loader, desc="Training", leave=False)
        correct = 0
        total = 0
        for imgs, labels in train_pbar:
            labels = labels.to(self.device)
            # Forward (imgs left as-is — model is assumed to place them;
            # TODO confirm against Classifier.forward)
            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)
            # Backward
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            total_loss += loss.item()
            train_pbar.set_postfix(acc=correct / total, loss=loss.item())
        avg_loss = total_loss / len(self.train_loader)
        avg_acc = correct / total
        return avg_loss, avg_acc

    def train(self, epochs=10, visualize_every=5):
        """Generator driving the full training loop.

        Yields per epoch ``(train_loss, train_acc, val_loss, val_acc, fig)``
        (val entries/fig are None without a val loader), then finally the
        accumulated history ``(train_losses, train_accs, val_losses,
        val_accs, None)``.

        Args:
            epochs: number of epochs to run.
            visualize_every: visualize validation predictions every N epochs
                (always on epoch 1).
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []
        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self.train_one_epoch()
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)
            if self.val_loader is not None:
                val_loss, val_acc, fig = self.validate(
                    epoch,
                    visualize=(epoch % visualize_every == 0 or epoch == 1)
                )
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
                yield train_loss, train_acc, val_loss, val_acc, fig
            else:
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
                yield train_loss, train_acc, None, None, None
        # Final summary yield with the full history lists.
        yield train_losses, train_accuracies, val_losses, val_accuracies, None

    def validate(self, epoch, visualize=False):
        """Evaluate on the validation loader.

        Args:
            epoch: current epoch number (unused internally; kept for callers).
            visualize: when True, render the first batch's predictions.

        Returns:
            (avg_loss, acc, fig) — fig is None unless ``visualize`` and
            ``self.return_fig`` produce one. Returns None (bare) when there
            is no val loader.
        """
        if self.val_loader is None:
            return
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        val_imgs_display = None
        val_preds_display = None
        val_labels_display = None
        val_pbar = tqdm(self.val_loader, desc="Validation", leave=False)
        fig = None
        with torch.no_grad():
            for imgs, labels in val_pbar:
                labels = labels.to(self.device)
                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                # Remember the first batch for visualization after the loop.
                if visualize and val_imgs_display is None:
                    val_imgs_display = imgs
                    val_preds_display = preds
                    val_labels_display = labels
                val_pbar.set_postfix(loss=loss.item(), acc=correct / total)
        avg_loss = total_loss / len(self.val_loader)
        acc = correct / total
        if visualize and val_imgs_display is not None:
            fig = self.visualize_batch(val_imgs_display, val_preds_display,
                                       val_labels_display, self.class_names)
        return avg_loss, acc, fig