VJyzCELERY's picture
Commit to hf space
c3d45c0
raw
history blame
6.86 kB
from src.model import Classifier
from src.dataloader import ImageDataset,collate_fn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
import torch.nn as nn
import time
def seed_worker(worker_id):
    """Seed NumPy and the stdlib ``random`` module from torch's initial seed.

    Has the single-argument signature of a ``DataLoader`` ``worker_init_fn``;
    ``worker_id`` is accepted for that signature but is not used.
    """
    # Reduce torch's seed modulo 2**32 before passing it to the other RNGs.
    derived_seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)
class ModelTrainer:
    """Run training and validation loops for a classifier.

    Wraps the datasets in ``DataLoader`` objects, optimises the model with
    Adam and cross-entropy loss, and can plot a sample of validation
    predictions together with the output of ``model.visualize_feature``.
    """

    def __init__(
        self,
        model: Classifier,
        train_set: ImageDataset,
        val_set: ImageDataset = None,
        batch_size=32,
        lr=1e-3,
        device='cpu',
        return_fig=False,
        seed=None,
    ):
        """Build the data loaders, the optimiser and the loss.

        Args:
            model: Network to train. Its ``classes`` attribute is stored and
                used for plot titles.
            train_set: Dataset iterated in shuffled order during training.
            val_set: Optional validation dataset; ``None`` disables
                validation.
            batch_size: Batch size used by both loaders.
            lr: Learning rate passed to Adam.
            device: Device that label tensors are moved to in the loops.
            return_fig: When true, ``visualize_batch`` returns its figures
                instead of showing and closing them.
            seed: Optional seed for the generator given to the training
                loader.
        """
        shuffle_generator = torch.Generator()
        if seed is not None:
            shuffle_generator.manual_seed(seed)

        self.train_loader = DataLoader(
            train_set,
            batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=shuffle_generator,
        )
        self.device = device

        if val_set is not None:
            self.val_loader = DataLoader(
                val_set,
                batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker,
            )
        else:
            self.val_loader = None

        self.class_names = model.classes
        self.model = model
        self.lr = lr
        self.optim = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        self.optim.zero_grad()
        self.criterion = nn.CrossEntropyLoss()
        self.return_fig = return_fig

    def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
        """Plot sampled predictions plus the model's feature visualisations.

        Args:
            imgs: Batch of images. A list is stacked and permuted with
                ``(0, 3, 1, 2)`` (presumably channel-last arrays -- confirm
                against the collate function); anything else must already
                support ``.cpu().numpy()`` and is displayed after a
                ``(1, 2, 0)`` transpose.
            preds: Tensor of predicted class indices for the batch.
            labels: Tensor of ground-truth class indices for the batch.
            class_names: Optional index-to-name lookup used in titles.
            max_samples: Maximum number of images drawn from the batch.

        Returns:
            The list ``[prediction_figure, *feature_figures]`` when
            ``self.return_fig`` is true; otherwise ``None`` after the figures
            have been shown and closed.
        """
        # Keep the batch as received: visualize_feature gets one of its items.
        source_imgs = imgs
        if isinstance(imgs, list):
            stacked = np.stack(imgs, axis=0)
            imgs = torch.from_numpy(stacked).permute(0, 3, 1, 2).float()
        imgs_np = imgs.cpu().numpy()
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()

        batch_size = imgs_np.shape[0]
        indices = random.sample(range(batch_size), min(max_samples, batch_size))
        first_image = source_imgs[indices[0]]

        fig_pred = plt.figure(figsize=(6 * len(indices), 5))
        grid = fig_pred.add_gridspec(1, len(indices))
        for col, idx in enumerate(indices):
            ax = fig_pred.add_subplot(grid[0, col])
            ax.imshow(imgs_np[idx].transpose(1, 2, 0))
            if class_names:
                title = f"P: {class_names[preds[idx]]} | T: {class_names[labels[idx]]}"
            else:
                title = f"P: {preds[idx]} | T: {labels[idx]}"
            ax.set_title(title)
            ax.axis("off")
        fig_pred.tight_layout()

        raw_features = self.model.visualize_feature(first_image, show=False)
        feature_figs = []
        for feature in raw_features:
            if isinstance(feature, plt.Figure):
                feature_figs.append(feature)
                continue
            # Objects exposing ``mode`` are converted to arrays before plotting.
            if hasattr(feature, "mode"):
                feature = np.array(feature)
            height, width = feature.shape[:2]
            dpi = 100
            fig = plt.figure(
                figsize=(max(4, width / dpi), max(4, height / dpi)),
                dpi=dpi,
            )
            ax = fig.add_subplot(111)
            ax.imshow(feature)
            ax.axis("off")
            feature_figs.append(fig)

        all_figs = [fig_pred] + feature_figs
        if self.return_fig:
            return all_figs

        plt.show()
        # Close every figure, not only the prediction grid; otherwise the
        # feature figures stay registered with pyplot and pile up each time
        # a batch is visualised.
        for fig in all_figs:
            plt.close(fig)
        return None

    def train_one_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            ``(avg_loss, avg_acc)``: the mean of the per-batch losses and the
            fraction of correctly classified samples.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
        for imgs, labels in train_pbar:
            labels = labels.to(self.device)

            # Forward
            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)

            # Backward
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            # Read the scalar once and reuse it for the sum and the display.
            batch_loss = loss.item()
            total_loss += batch_loss
            train_pbar.set_postfix(acc=correct / total, loss=batch_loss)

        avg_loss = total_loss / len(self.train_loader)
        avg_acc = correct / total
        return avg_loss, avg_acc

    def train(self, epochs=10, visualize_every=5):
        """Train for ``epochs`` epochs, yielding metrics as a generator.

        Args:
            epochs: Number of epochs to run.
            visualize_every: Validation predictions are visualised on epoch 1
                and on every epoch divisible by this value. A falsy value
                (``0`` or ``None``) disables visualisation instead of raising
                on the modulo.

        Yields:
            After each epoch, ``(train_loss, train_acc, val_loss, val_acc,
            fig)``; the last three are ``None`` when there is no validation
            loader. After the final epoch, one extra tuple
            ``(train_losses, train_accuracies, val_losses, val_accuracies,
            None)`` holding the full per-epoch histories.
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []
        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self.train_one_epoch()
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)

            if self.val_loader is None:
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
                yield train_loss, train_acc, None, None, None
                continue

            # Guard the modulo so visualize_every=0/None cannot raise.
            show_now = bool(visualize_every) and (
                epoch == 1 or epoch % visualize_every == 0
            )
            val_loss, val_acc, fig = self.validate(epoch, visualize=show_now)
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
            yield train_loss, train_acc, val_loss, val_acc, fig

        yield train_losses, train_accuracies, val_losses, val_accuracies, None

    def validate(self,epoch, visualize=False):
        """Evaluate the model on the validation loader.

        Args:
            epoch: Current epoch number; accepted but not used here.
            visualize: When true, the first validation batch is passed to
                ``visualize_batch`` after the loop.

        Returns:
            ``None`` when no validation loader was configured; otherwise
            ``(avg_loss, acc, fig)`` where ``fig`` is whatever
            ``visualize_batch`` returned, or ``None`` if nothing was
            visualised.
        """
        if self.val_loader is None:
            return

        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        # First batch of the epoch, kept only when a visualisation is wanted.
        val_imgs_display = None
        val_preds_display = None
        val_labels_display = None
        fig = None
        val_pbar = tqdm(self.val_loader, desc="Validation",leave=False)
        with torch.no_grad():
            for imgs, labels in val_pbar:
                labels = labels.to(self.device)
                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
                if visualize and val_imgs_display is None:
                    val_imgs_display = imgs
                    val_preds_display = preds
                    val_labels_display = labels
                val_pbar.set_postfix(loss=loss.item(), acc=correct / total)

        avg_loss = total_loss / len(self.val_loader)
        acc = correct / total
        if visualize and val_imgs_display is not None:
            fig = self.visualize_batch(
                val_imgs_display,
                val_preds_display,
                val_labels_display,
                self.class_names,
            )
        return avg_loss, acc, fig