unfair221's picture
Upload folder using huggingface_hub
0bae2fa verified
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from mymodel import MyCIFAR10Net
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
# 1. Data loading
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(trainset))
valid_size = len(trainset) - train_size
train_subset, valid_subset = random_split(trainset, [train_size, valid_size])
trainloader = DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(valid_subset, batch_size=128, shuffle=False, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# 2. Model, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyCIFAR10Net(num_classes=10, use_batchnorm=True, use_dropout=True, activation='leakyrelu').to(device)
loss_fn = nn.CrossEntropyLoss() # Try different loss functions here
optimizer = optim.Adam(model.parameters(), lr=0.001) # Try different optimizers here
# 切换优化器
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-4)
epoch_losses = []
def train(num_epochs=10):
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
# loss = loss_fn(outputs, labels)
l2_lambda = 1e-4 # L2正则化强度
l2_reg = torch.tensor(0., device=device)
for param in model.parameters():
l2_reg += torch.norm(param, 2)
loss = loss_fn(outputs, labels) + l2_lambda * l2_reg
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss/len(trainloader)
print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
epoch_losses.append(avg_loss)
validate()
# Save best model if needed
import time
time_stamp = time.strftime("%Y%m%d-%H%M%S")
path=f'model/best_model_{time_stamp}.pth'
torch.save(model.state_dict(), path)
# 保存loss曲线
plt.figure()
plt.plot(range(1, len(epoch_losses)+1), epoch_losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss per Epoch')
plt.savefig('loss_curve.png')
plt.close()
def validate():
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in valloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Validation Accuracy: {100 * correct / total:.2f}%')
def save_all_conv_filters(model, filename='all_filters.png', layer='conv1'):
if layer == 'conv1':
filters = model.conv1.weight.data.clone().cpu()
elif layer == 'conv2':
filters = model.conv2.weight.data.clone().cpu()
else:
raise ValueError("layer must be 'conv1' or 'conv2'")
num_filters = filters.shape[0]
ncols = 8
nrows = (num_filters + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows*2))
for i in range(num_filters):
r, c = divmod(i, ncols)
f = filters[i]
f_min, f_max = f.min(), f.max()
f = (f - f_min) / (f_max - f_min)
if f.shape[0] == 3: # RGB
axes[r, c].imshow(f.permute(1, 2, 0))
else: # 单通道
axes[r, c].imshow(f[0], cmap='gray')
axes[r, c].axis('off')
for i in range(num_filters, nrows * ncols):
r, c = divmod(i, ncols)
axes[r, c].axis('off')
plt.tight_layout()
plt.savefig(filename)
plt.close()
def visualize_all_feature_maps(model, image, filename='feature_maps.png', after='conv1'):
model.eval()
with torch.no_grad():
x = image.unsqueeze(0).to(next(model.parameters()).device)
x = model.conv1(x)
x = model.bn1(x)
x = model._activate(x)
if after == 'conv2':
x = model.pool(x)
x = model.conv2(x)
x = model.bn2(x)
x = model._activate(x)
feature_maps = x.cpu().squeeze(0)
num_maps = feature_maps.shape[0]
ncols = 8
nrows = (num_maps + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows*2))
for i in range(num_maps):
r, c = divmod(i, ncols)
fmap = feature_maps[i]
fmap_min, fmap_max = fmap.min(), fmap.max()
fmap = (fmap - fmap_min) / (fmap_max - fmap_min)
axes[r, c].imshow(fmap, cmap='viridis')
axes[r, c].axis('off')
for i in range(num_maps, nrows * ncols):
r, c = divmod(i, ncols)
axes[r, c].axis('off')
plt.tight_layout()
plt.savefig(filename)
plt.close()
def plot_loss_landscape(model, dataloader, loss_fn, steps=20, alpha=0.5):
w = model.fc1.weight.data.clone()
direction1 = torch.randn_like(w)
direction2 = torch.randn_like(w)
losses = np.zeros((steps, steps))
device = next(model.parameters()).device
for i, a in enumerate(np.linspace(-alpha, alpha, steps)):
for j, b in enumerate(np.linspace(-alpha, alpha, steps)):
model.fc1.weight.data = w + a * direction1 + b * direction2
total_loss = 0
count = 0
for inputs, labels in dataloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = loss_fn(outputs, labels)
total_loss += loss.item()
count += 1
if count > 2: # 只用少量batch加速
break
losses[i, j] = total_loss / count
model.fc1.weight.data = w # 恢复原权重
plt.figure(figsize=(6,5))
plt.contourf(losses, levels=50, cmap='viridis')
plt.colorbar()
plt.title('Loss Landscape (fc1 weight)')
plt.xlabel('Direction 1')
plt.ylabel('Direction 2')
plt.savefig('loss_landscape.png')
plt.close()
if __name__ == "__main__":
train(num_epochs=10)
# model.eval()
# save_conv1_filters(model)
# 加载已有模型
# model.load_state_dict(torch.load('best_model.pth', map_location=device))
# model.eval()
# # 保存所有卷积核
# save_all_conv_filters(model, filename='all_filters_conv1.png', layer='conv1')
# save_all_conv_filters(model, filename='all_filters_conv2.png', layer='conv2')
# # 取一张验证集图片
# sample_img, _ = next(iter(valloader))
# # 可视化feature map
# visualize_all_feature_maps(model, sample_img[0], filename='feature_maps_conv1.png', after='conv1')
# visualize_all_feature_maps(model, sample_img[0], filename='feature_maps_conv2.png', after='conv2')
# # Loss landscape visualization
# plot_loss_landscape(model, valloader, loss_fn)
# Evaluate on test set after training
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')
print(f'Test Error: {100 - 100 * correct / total:.2f}%')