|
|
import torch |
|
|
import torchvision |
|
|
import torchvision.transforms as transforms |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
from mymodel import MyCIFAR10Net |
|
|
import torch.optim as optim |
|
|
import torch.nn as nn |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Preprocessing applied to every split: convert each image to a tensor, then
# subtract 0.5 and divide by 0.5 on each of the three channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split (downloaded into ./data when not already present).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Hold out 20% of the training split for validation; the remainder after the
# integer truncation goes to the validation subset so the sizes always add up.
train_size = int(0.8 * len(trainset))
valid_size = len(trainset) - train_size
train_subset, valid_subset = random_split(
    trainset,
    lengths=[train_size, valid_size],
)

# Only the training loader shuffles; validation and test keep dataset order.
trainloader = DataLoader(
    dataset=train_subset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
valloader = DataLoader(
    dataset=valid_subset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

# CIFAR-10 test split, preprocessed exactly like the training data.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)
|
|
|
|
|
|
|
|
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network under training, created with batch norm, dropout and the
# 'leakyrelu' activation option enabled, then moved to the chosen device.
model = MyCIFAR10Net(
    num_classes=10,
    use_batchnorm=True,
    use_dropout=True,
    activation='leakyrelu',
).to(device)

# Classification loss and optimiser shared by the training loop.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Mean training loss of every completed epoch; appended to by train() and
# plotted once training has finished.
epoch_losses = list()
|
|
|
|
|
def train(num_epochs=10):
    """Train the module-level ``model`` and save its weights and a loss curve.

    For every batch the optimised loss is ``loss_fn(outputs, labels)`` plus
    ``1e-4`` times the sum of the (un-squared) L2 norms of all model
    parameters. After each epoch the mean batch loss is printed, appended to
    the module-level ``epoch_losses`` list, and ``validate()`` is called.

    When all epochs are done, the model's ``state_dict`` is written to
    ``model/best_model_<timestamp>.pth`` and the contents of ``epoch_losses``
    are plotted to ``loss_curve.png``. Despite the file name, the saved
    weights are those after the final epoch; no best-epoch selection happens.
    Because ``epoch_losses`` is module-level, calling ``train`` repeatedly
    keeps extending the same curve.

    NOTE(review): the indentation of this file was lost; the checkpoint save
    and the plot are assumed to run once after the last epoch rather than
    inside the epoch loop -- confirm against the original script.

    Args:
        num_epochs: Number of passes over ``trainloader``.
    """
    import os
    import time

    # Weight of the L2 penalty; constant, so defined once outside the loops.
    l2_lambda = 1e-4

    for epoch in range(num_epochs):
        # validate() switches the model to eval mode at the end of every
        # epoch, so training mode has to be re-enabled here each time.
        model.train()
        running_loss = 0.0

        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            # Penalty term: sum of the L2 norms of every parameter tensor.
            l2_reg = torch.tensor(0., device=device)
            for param in model.parameters():
                l2_reg += torch.norm(param, 2)

            loss = loss_fn(outputs, labels) + l2_lambda * l2_reg
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(trainloader)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        epoch_losses.append(avg_loss)
        validate()

    # Persist the trained weights. The target directory is created first so a
    # finished training run is not lost to a missing 'model/' folder.
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    path = f'model/best_model_{time_stamp}.pth'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)

    # Plot the per-epoch mean training loss, numbering epochs from 1.
    plt.figure()
    plt.plot(range(1, len(epoch_losses)+1), epoch_losses, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss per Epoch')
    plt.savefig('loss_curve.png')
    plt.close()
|
|
|
|
|
def validate():
    """Print the accuracy of the module-level ``model`` on ``valloader``.

    The model is put into eval mode and is left in that mode when the
    function returns. The predicted class for each sample is the index of
    the largest output along dimension 1. Nothing is returned; the accuracy
    is only printed as a percentage with two decimals.
    """
    model.eval()
    correct, total = 0, 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_inputs, batch_labels in valloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            predicted = torch.max(logits.data, dim=1).indices

            total += batch_labels.size(0)
            hits = (predicted == batch_labels).sum().item()
            correct += hits

    print(f'Validation Accuracy: {100 * correct / total:.2f}%')
|
|
|
|
|
def save_all_conv_filters(model, filename='all_filters.png', layer='conv1'):
    """Save an image grid showing every filter of ``model.conv1`` or ``model.conv2``.

    Each filter is min-max normalised on its own before being drawn. Filters
    whose first dimension is 3 are drawn as colour images (channels moved
    last); for any other channel count only the first channel, ``f[0]``, is
    drawn in grayscale. The grid has 8 columns and as many rows as needed;
    unused cells are left blank.

    Args:
        model: Object exposing ``conv1`` / ``conv2`` attributes that have a
            ``weight`` tensor.
        filename: Path the figure is saved to.
        layer: Which layer to plot, ``'conv1'`` or ``'conv2'``.

    Raises:
        ValueError: If ``layer`` is neither ``'conv1'`` nor ``'conv2'``.
    """
    if layer not in ('conv1', 'conv2'):
        raise ValueError("layer must be 'conv1' or 'conv2'")
    filters = getattr(model, layer).weight.detach().clone().cpu()

    num_filters = filters.shape[0]
    ncols = 8
    nrows = (num_filters + ncols - 1) // ncols  # ceiling division

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row (8 filters or fewer); otherwise 2-D indexing would raise IndexError.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False
    )

    # axes.flat walks the grid row by row, i.e. cell i is (i // ncols, i % ncols).
    for idx, ax in enumerate(axes.flat):
        if idx < num_filters:
            f = filters[idx]
            f_min, f_max = f.min(), f.max()
            span = f_max - f_min
            # A constant filter would give 0/0 (NaN); draw it as all zeros.
            f = (f - f_min) / span if span > 0 else torch.zeros_like(f)
            if f.shape[0] == 3:
                ax.imshow(f.permute(1, 2, 0).numpy())
            else:
                ax.imshow(f[0].numpy(), cmap='gray')
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
|
|
|
|
|
def visualize_all_feature_maps(model, image, filename='feature_maps.png', after='conv1'):
    """Save an image grid of the feature maps produced for one input image.

    The image is passed through ``conv1 -> bn1 -> _activate``; when ``after``
    equals ``'conv2'`` it is additionally passed through
    ``pool -> conv2 -> bn2 -> _activate``. Any other value of ``after``
    behaves like ``'conv1'``. Every resulting channel is min-max normalised
    on its own and drawn with the 'viridis' colormap in a grid with 8 columns.

    The model is switched to eval mode and is left in eval mode afterwards.

    Args:
        model: Network exposing the attributes named above.
            NOTE(review): ``bn1`` / ``bn2`` are accessed unconditionally;
            verify they exist when the model is built without batch norm.
        image: Input tensor without a batch dimension (one is added with
            ``unsqueeze(0)``); presumably (C, H, W) -- confirm with callers.
        filename: Path the figure is saved to.
        after: ``'conv1'`` or ``'conv2'``; selects the stage to visualise.
    """
    model.eval()
    with torch.no_grad():
        # Add the batch dimension and run on whatever device the model uses.
        x = image.unsqueeze(0).to(next(model.parameters()).device)
        x = model._activate(model.bn1(model.conv1(x)))
        if after == 'conv2':
            x = model.pool(x)
            x = model._activate(model.bn2(model.conv2(x)))

    # Drop the batch dimension again; the first dimension now indexes channels.
    feature_maps = x.cpu().squeeze(0)
    num_maps = feature_maps.shape[0]
    ncols = 8
    nrows = (num_maps + ncols - 1) // ncols  # ceiling division

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row (8 maps or fewer); otherwise 2-D indexing would raise IndexError.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False
    )

    # axes.flat walks the grid row by row, i.e. cell i is (i // ncols, i % ncols).
    for idx, ax in enumerate(axes.flat):
        if idx < num_maps:
            fmap = feature_maps[idx]
            fmap_min, fmap_max = fmap.min(), fmap.max()
            span = fmap_max - fmap_min
            # A constant map would give 0/0 (NaN); draw it as all zeros.
            fmap = (fmap - fmap_min) / span if span > 0 else torch.zeros_like(fmap)
            ax.imshow(fmap.numpy(), cmap='viridis')
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
|
|
|
|
|
def plot_loss_landscape(model, dataloader, loss_fn, steps=20, alpha=0.5):
    """Plot the loss on a 2-D slice through the ``fc1`` weight space.

    Two random directions with the shape of ``model.fc1.weight`` are drawn.
    For every pair ``(a, b)`` on a ``steps x steps`` grid spanning
    ``[-alpha, alpha]`` the weight is set to ``w + a * d1 + b * d2`` and the
    mean loss over up to three batches from ``dataloader`` is recorded. The
    result is saved as a filled contour plot in ``loss_landscape.png``.

    The same batches are reused at every grid point, so the surface is not
    distorted by a shuffling loader. Evaluation runs in eval mode without
    gradient tracking. The original ``fc1`` weight and the model's
    train/eval mode are restored afterwards, even if evaluation fails.

    Args:
        model: Network with an ``fc1`` layer that has a ``weight`` parameter.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable computing the loss from ``(outputs, labels)``.
        steps: Number of grid points along each direction.
        alpha: Largest coefficient applied to each direction.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    max_batches = 3
    target_device = next(model.parameters()).device

    weight = model.fc1.weight
    w = weight.data.clone()
    direction1 = torch.randn_like(w)
    direction2 = torch.randn_like(w)

    # Fetch the probe batches once instead of re-iterating the loader (and
    # restarting its workers) at each of the steps * steps grid points.
    batches = []
    for inputs, labels in dataloader:
        batches.append((inputs.to(target_device), labels.to(target_device)))
        if len(batches) >= max_batches:
            break
    if not batches:
        raise ValueError("dataloader yielded no batches")

    coords = np.linspace(-alpha, alpha, steps)
    losses = np.zeros((steps, steps))

    was_training = model.training
    # Eval mode keeps the probe deterministic and stops it from updating any
    # batch-norm running statistics.
    model.eval()
    try:
        with torch.no_grad():
            for i, a in enumerate(coords):
                for j, b in enumerate(coords):
                    # float() keeps the arithmetic in torch rather than
                    # mixing in NumPy scalar types.
                    weight.copy_(w + float(a) * direction1 + float(b) * direction2)
                    total_loss = sum(
                        loss_fn(model(inputs), labels).item()
                        for inputs, labels in batches
                    )
                    losses[i, j] = total_loss / len(batches)
    finally:
        with torch.no_grad():
            weight.copy_(w)
        model.train(was_training)

    plt.figure(figsize=(6,5))
    # losses[i, j] has direction 1 on its first axis, while contourf expects
    # Z indexed as [y, x]; transpose so direction 1 really lies on the x axis.
    plt.contourf(coords, coords, losses.T, levels=50, cmap='viridis')
    plt.colorbar()
    plt.title('Loss Landscape (fc1 weight)')
    plt.xlabel('Direction 1')
    plt.ylabel('Direction 2')
    plt.savefig('loss_landscape.png')
    plt.close()
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the indentation of this file was lost; the test-set
    # evaluation below is assumed to belong to the script entry point.
    train(num_epochs=10)

    # Evaluate the trained model on the test loader. Eval mode is set
    # explicitly instead of relying on validate() having been the last call
    # made inside train().
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest output along dimension 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_accuracy = 100 * correct / total
    print(f'Test Accuracy: {test_accuracy:.2f}%')
    print(f'Test Error: {100 - test_accuracy:.2f}%')