Initial commit of transfer learning project files
Browse files- config.yaml +42 -0
- models/ResNet50_ImageNet_224px_best.pth +3 -0
- src/__init__.py +0 -0
- src/dataset.py +47 -0
- src/engine.py +79 -0
- src/model.py +97 -0
- src/path.py +23 -0
- src/utils.py +25 -0
- start_sweep.sh +46 -0
- sweep.yaml +82 -0
- train.py +229 -0
config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#网页设置
|
| 2 |
+
wandb_setup:
|
| 3 |
+
project: "cifar10_transfer_learning"
|
| 4 |
+
experiment: "ResNet50_ImageNet_224px"
|
| 5 |
+
tags: ["cifar10","resnet50"]
|
| 6 |
+
seed: 42
|
| 7 |
+
job_type: "train"
|
| 8 |
+
|
| 9 |
+
#数据加载参数
|
| 10 |
+
data:
|
| 11 |
+
data_path: "./data"
|
| 12 |
+
batch_size: 64
|
| 13 |
+
num_workers: 4
|
| 14 |
+
image_size: [224,224]
|
| 15 |
+
in_channels: 3
|
| 16 |
+
|
| 17 |
+
#模型结构参数
|
| 18 |
+
model:
|
| 19 |
+
type: "TransferResNet50"
|
| 20 |
+
dropout_rate: 0.0
|
| 21 |
+
num_classes: 10
|
| 22 |
+
|
| 23 |
+
#训练超参数
|
| 24 |
+
train:
|
| 25 |
+
epochs: 30
|
| 26 |
+
save_dir: "./models"
|
| 27 |
+
|
| 28 |
+
#优化器与调度器
|
| 29 |
+
optimizer:
|
| 30 |
+
name: "adamw"
|
| 31 |
+
lr: 0.001
|
| 32 |
+
backbone_lr: 0.00005
|
| 33 |
+
weight_decay: 1e-3
|
| 34 |
+
|
| 35 |
+
scheduler:
|
| 36 |
+
use_scheduler: True
|
| 37 |
+
type: "CosineAnnealingLR"
|
| 38 |
+
T_max: 30  # note: train.py passes train.epochs as T_max, so this value is currently ignored
|
| 39 |
+
eta_min: 1e-6
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
models/ResNet50_ImageNet_224px_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca15171214c8a6d75fca585e6ce2b16ba4b6b735bbba58d92f729936a4b16a02
|
| 3 |
+
size 94445757
|
src/__init__.py
ADDED
|
File without changes
|
src/dataset.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from torchvision import datasets,transforms
|
| 4 |
+
from src.path import DATA_DIR
|
| 5 |
+
|
| 6 |
+
def get_dataloader(config):
    """Build CIFAR-10 train/test DataLoaders with ImageNet-style preprocessing.

    Args:
        config: dict-like with optional keys 'batch_size', 'data_path',
            'num_workers', and 'image_size' (``[H, W]``).

    Returns:
        (train_loader, test_loader) tuple of DataLoaders.
    """
    batch_size = config.get('batch_size', 64)
    data_path = config.get('data_path', DATA_DIR)
    num_workers = config.get('num_workers', 4)
    # Fix: honour the configured image_size instead of hard-coding 224x224
    # (config.yaml declares data.image_size, which was previously ignored).
    image_size = tuple(config.get('image_size', (224, 224)))

    # ImageNet statistics, matching the pre-trained backbone's expectations.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=val_transform)

    # Fix: persistent_workers requires num_workers > 0; passing True with
    # num_workers=0 raises ValueError at DataLoader construction.
    persistent = num_workers > 0

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        persistent_workers=persistent,
    )

    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        persistent_workers=persistent,
    )

    return train_loader, test_loader
src/engine.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import wandb
|
| 4 |
+
from torch.amp import autocast, GradScaler
|
| 5 |
+
|
| 6 |
+
def train_one_epoch(epoch_id,model,data_loader,loss_fn,optimizer,device,scaler):
|
| 7 |
+
model.train()
|
| 8 |
+
training_loss = 0.0
|
| 9 |
+
running_correct = 0
|
| 10 |
+
total_samples = 0
|
| 11 |
+
|
| 12 |
+
for batch,(X,y) in enumerate(data_loader):
|
| 13 |
+
if not X.is_cuda:
|
| 14 |
+
X,y = X.to(device,non_blocking=True),y.to(device,non_blocking=True)
|
| 15 |
+
X = X.to(memory_format=torch.channels_last)
|
| 16 |
+
optimizer.zero_grad(set_to_none=True)
|
| 17 |
+
|
| 18 |
+
with autocast('cuda',dtype=torch.float16):
|
| 19 |
+
pred = model(X)
|
| 20 |
+
loss = loss_fn(pred,y)
|
| 21 |
+
|
| 22 |
+
scaler.scale(loss).backward()
|
| 23 |
+
scaler.step(optimizer)
|
| 24 |
+
scaler.update()
|
| 25 |
+
|
| 26 |
+
pred_ids = pred.argmax(1)
|
| 27 |
+
running_correct += (pred_ids == y).type(torch.int).sum().item()
|
| 28 |
+
total_samples += y.size(0)
|
| 29 |
+
|
| 30 |
+
training_loss += loss.item()
|
| 31 |
+
|
| 32 |
+
train_epoch_loss = training_loss / len(data_loader)
|
| 33 |
+
train_epoch_acc = running_correct / total_samples
|
| 34 |
+
|
| 35 |
+
return train_epoch_loss,train_epoch_acc
|
| 36 |
+
|
| 37 |
+
def evaluate(epoch_id, model, data_loader, loss_fn, device):
    """Evaluate the model on data_loader.

    Returns:
        (mean_batch_loss, accuracy, bad_cases) where bad_cases holds up to 20
        wandb.Image objects of misclassified samples (de-normalized for viewing).
    """
    model.eval()
    testing_loss = 0.0
    testing_correct = 0
    total_samples = 0
    bad_cases = []

    # Fix: hoisted out of the loops — these constants were previously rebuilt
    # for every misclassified sample. ImageNet normalization stats, reshaped
    # to (C, 1, 1) so they broadcast over a (C, H, W) image.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    with torch.no_grad():
        for X, y in data_loader:
            if not X.is_cuda:
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

            pred = model(X)
            loss = loss_fn(pred, y)
            testing_loss += loss.item()

            pred_ids = pred.argmax(1)
            testing_correct += (pred_ids == y).type(torch.int).sum().item()
            total_samples += y.size(0)

            # Collect a few misclassified images for the wandb dashboard.
            if len(bad_cases) < 20:
                wrong_idx = (pred_ids != y).nonzero()
                for idx in wrong_idx:
                    if len(bad_cases) < 20:
                        raw_img = X[idx.item()].cpu()

                        # Undo the Normalize transform so the image is viewable.
                        img = raw_img * std + mean
                        img = torch.clamp(img, 0, 1)

                        bad_cases.append(
                            wandb.Image(img, caption=f"Pred: {pred_ids[idx].item()} | True: {y[idx].item()}")
                        )

    val_epoch_loss = testing_loss / len(data_loader)
    val_epoch_acc = testing_correct / total_samples

    return val_epoch_loss, val_epoch_acc, bad_cases
+
|
| 79 |
+
|
src/model.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from torchvision.models import resnet18,resnet50,ResNet50_Weights
|
| 5 |
+
|
| 6 |
+
class SimpleCNN(nn.Module):
    """Two-stage CNN baseline: (Conv-BN-ReLU-Pool) x2 followed by an MLP head.

    Intended for small images (default 28x28, single channel). Each pooling
    stage halves the spatial resolution, so the flattened feature size is
    64 * (input_size // 4) ** 2.
    """

    def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
        super().__init__()

        self.features = nn.Sequential(
            # Stage 1: num_inputs -> 32 channels, spatial size halved.
            nn.Conv2d(num_inputs, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Stage 2: 32 -> 64 channels, spatial size halved again.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        side = input_size // 4  # two 2x2 max-pools => /4 per spatial dim
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * side * side, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Kaiming init suits ReLU networks; BN layers start as identity maps.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.classifier(self.features(x))
| 55 |
+
|
| 56 |
+
class ResNet18_CIFAR(nn.Module):
    """ResNet-18 trained from scratch, adapted for 32x32 CIFAR inputs.

    The stock ImageNet stem (7x7/stride-2 conv + max-pool) downsamples too
    aggressively for 32x32 images, so it is swapped for a 3x3/stride-1 conv
    and the pooling layer is removed. Flip/crop augmentation runs inside
    forward() during training only.
    """

    def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
        # dropout_rate is accepted for interface parity with the other models
        # but is not used by this architecture.
        super().__init__()

        # Train-time-only augmentation, applied on-device in forward().
        self.aug = nn.Sequential(
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        )

        backbone = resnet18(weights=None)
        # CIFAR stem: keep full 32x32 resolution through the first stage.
        backbone.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Linear(512, num_classes)
        self.net = backbone

    def forward(self, x):
        if self.training:
            x = self.aug(x)
        return self.net(x)
| 76 |
+
|
| 77 |
+
class TransferResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 with a fresh dropout + linear head.

    The entire backbone remains trainable (nothing is frozen); the training
    script is expected to give it a much smaller learning rate than the new
    classification head.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0):
        super().__init__()

        print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
        self.net = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Replace the 1000-way ImageNet classifier with a task-specific head.
        in_features = self.net.fc.in_features
        self.net.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        return self.net(x)
src/path.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
def get_project_root() -> Path:
    """Return the absolute path of the project root directory."""
    if getattr(sys, 'frozen', False):
        # Frozen executable (e.g. PyInstaller): root is next to the binary.
        return Path(sys.executable).parent
    # Development checkout: this file lives in <root>/src/, so go up twice.
    return Path(__file__).resolve().parent.parent
| 16 |
+
|
| 17 |
+
# Resolved once at import time; all project paths derive from this.
PROJECT_ROOT = get_project_root()
CONFIG_PATH = PROJECT_ROOT / 'config.yaml'  # main experiment configuration
DATA_DIR = PROJECT_ROOT / 'data'            # dataset download/cache location
MODELS_DIR = PROJECT_ROOT / 'models'        # checkpoint output location

# Ensure writable directories exist (parents assumed present; exist_ok avoids
# races when several processes import this module).
for directory in [DATA_DIR, MODELS_DIR]:
    directory.mkdir(exist_ok=True)
src/utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
def get_device():
    """Pick the best available compute device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
| 14 |
+
|
| 15 |
+
def seed_everthing(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for (mostly) reproducible runs.

    Note: cudnn.benchmark stays True and cudnn.deterministic stays False,
    deliberately trading bit-exact reproducibility for GPU throughput.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Speed over strict determinism: let cuDNN autotune kernels.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
| 25 |
+
|
start_sweep.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Launch N wandb sweep agents in parallel inside a tiled tmux session.

# === Edit these ===
# 1. Sweep ID (printed by `wandb sweep sweep.yaml`)
SWEEP_ID="1217820711-sun-yat-sen-university/cifar10_chanllenge/srzfvp0g"

# 2. Number of agents to run in parallel
NUM_AGENTS=2

# 3. Tmux session name
SESSION_NAME="sweep_resnet18_try1"

# 4. Conda environment name
CONDA_ENV="deep_learning"
# ===============

# Kill a pre-existing session with the same name to avoid a launch error.
# Fix: test the command's exit status directly instead of the non-POSIX
# `[ $? == 0 ]`, and quote all variable expansions.
if tmux has-session -t "$SESSION_NAME" 2>/dev/null; then
    echo "⚠️ Session $SESSION_NAME already exists. Killing it..."
    tmux kill-session -t "$SESSION_NAME"
fi

# Create a fresh detached session.
tmux new-session -d -s "$SESSION_NAME"

# Create one pane per agent and start it.
for ((i = 1; i <= NUM_AGENTS; i++)); do
    # Split the window for every agent after the first.
    if [ "$i" -gt 1 ]; then
        tmux split-window -t "$SESSION_NAME"
        tmux select-layout -t "$SESSION_NAME" tiled
    fi

    # Send: activate the conda env, then run the sweep agent (C-m = Enter).
    tmux send-keys -t "$SESSION_NAME" "conda activate $CONDA_ENV" C-m
    tmux send-keys -t "$SESSION_NAME" "wandb agent $SWEEP_ID" C-m

    echo "🚀 Agent $i started..."
done

# Attach so the user can watch the agents.
echo "All agents running! Attaching..."
tmux attach -t "$SESSION_NAME"
sweep.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
program: train.py
|
| 2 |
+
# 注意:Hyperband 必须配合 random 使用,而不是 bayes
|
| 3 |
+
# 因为 Hyperband 依靠随机采样来覆盖搜索空间,然后靠剪枝来提高效率
|
| 4 |
+
method: random
|
| 5 |
+
|
| 6 |
+
project: "cifar10_chanllenge"
|
| 7 |
+
name: "20251210-Hyperband-AdamW-LrSearch"
|
| 8 |
+
description: >
|
| 9 |
+
本次实验目的是为了验证 ResNet18 在 CIFAR-10 上
|
| 10 |
+
使用 AdamW 配合强正则化 (Weight Decay > 0.01) 的效果。
|
| 11 |
+
使用了全显存加载优化。
|
| 12 |
+
run_cap: 100
|
| 13 |
+
command:
|
| 14 |
+
- ${env}
|
| 15 |
+
- ${interpreter}
|
| 16 |
+
- ${program}
|
| 17 |
+
- ${args}
|
| 18 |
+
|
| 19 |
+
metric:
|
| 20 |
+
name: test_epoch_acc
|
| 21 |
+
goal: maximize
|
| 22 |
+
|
| 23 |
+
# 🔥 核心:Hyperband 提前终止策略
|
| 24 |
+
early_terminate:
|
| 25 |
+
type: hyperband
|
| 26 |
+
# 最小迭代次数:跑满 10 个 Epoch 后才开始评估是否要杀掉
|
| 27 |
+
# 避免模型还没热身就被误杀
|
| 28 |
+
min_iter: 10
|
| 29 |
+
# 淘汰比例:每次淘汰 2/3 的落后分子,保留 1/3 进入下一轮
|
| 30 |
+
eta: 3
|
| 31 |
+
|
| 32 |
+
parameters:
|
| 33 |
+
project_name:
|
| 34 |
+
value: "cifar10_hyperband_search"
|
| 35 |
+
|
| 36 |
+
# --- 训练轮数 ---
|
| 37 |
+
train:
|
| 38 |
+
parameters:
|
| 39 |
+
epochs:
|
| 40 |
+
# 这里设置最大轮数。Hyperband 会自动在中间截断
|
| 41 |
+
# 设为 150,保证“幸存者”能跑完全程,收敛到极致
|
| 42 |
+
value: 150
|
| 43 |
+
|
| 44 |
+
# --- 数据参数 ---
|
| 45 |
+
data:
|
| 46 |
+
parameters:
|
| 47 |
+
batch_size:
|
| 48 |
+
# 搜索区间:涵盖了 SGD 喜欢的小 Batch 和 AdamW 喜欢的大 Batch
|
| 49 |
+
values: [256, 512, 1024]
|
| 50 |
+
|
| 51 |
+
# --- 模型参数 ---
|
| 52 |
+
model:
|
| 53 |
+
parameters:
|
| 54 |
+
type:
|
| 55 |
+
value: "ResNet18"
|
| 56 |
+
num_classes:
|
| 57 |
+
value: 10
|
| 58 |
+
dropout_rate:
|
| 59 |
+
# ResNet 自带 BN,通常不需要大 Dropout,搜一个小范围
|
| 60 |
+
distribution: uniform
|
| 61 |
+
min: 0.0
|
| 62 |
+
max: 0.2
|
| 63 |
+
|
| 64 |
+
# --- 优化器 (搜索重点) ---
|
| 65 |
+
optimizer:
|
| 66 |
+
parameters:
|
| 67 |
+
name:
|
| 68 |
+
# 同时尝试 SGD (传统SOTA王者) 和 AdamW (现代万金油)
|
| 69 |
+
values: ['sgd', 'adamw']
|
| 70 |
+
|
| 71 |
+
lr:
|
| 72 |
+
# 学习率跨度要大!因为 SGD 需要 ~0.1,而 AdamW 需要 ~0.001
|
| 73 |
+
# log_uniform_values 会在对数尺度上均匀采样,保证两头都能搜到
|
| 74 |
+
distribution: log_uniform_values
|
| 75 |
+
min: 0.0001
|
| 76 |
+
max: 0.2
|
| 77 |
+
|
| 78 |
+
weight_decay:
|
| 79 |
+
# 正则化力度的搜索
|
| 80 |
+
distribution: log_uniform_values
|
| 81 |
+
min: 1e-4 # 0.0001 (适合 SGD)
|
| 82 |
+
max: 1e-1 # 0.1 (适合 AdamW)
|
train.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
torch.set_float32_matmul_precision('high')
|
| 3 |
+
import os
|
| 4 |
+
import yaml
|
| 5 |
+
import wandb
|
| 6 |
+
from torch import nn
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import sys
|
| 9 |
+
from torch.amp import GradScaler
|
| 10 |
+
|
| 11 |
+
# Point torch.compile's C++ codegen at the system toolchain.
os.environ["CXX"] = "/usr/bin/g++"
os.environ["CC"] = "/usr/bin/gcc"

ROOT_DIR = Path(__file__).resolve().parent
# Fix: sys.path holds strings, so comparing the Path object was always True
# and appended a duplicate entry on every import; compare the string form.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
| 17 |
+
|
| 18 |
+
from src.dataset import get_dataloader
|
| 19 |
+
from src.utils import get_device,seed_everthing
|
| 20 |
+
from src.model import ResNet18_CIFAR,SimpleCNN,TransferResNet50
|
| 21 |
+
from src.engine import train_one_epoch,evaluate
|
| 22 |
+
|
| 23 |
+
def load_yaml(config_path=None):
    """Load a YAML config file, defaulting to <project root>/config.yaml.

    Exits the process with status 1 when the file is missing, since the
    trainer cannot proceed without a configuration.
    """
    if config_path is None:
        config_path = ROOT_DIR / 'config.yaml'
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
            return config
    except FileNotFoundError:
        print(f"{config_path} File not found!!")
        # Fix: sys.exit instead of the interactive-only `exit` builtin
        # (which may be absent when site initialization is skipped).
        sys.exit(1)
| 33 |
+
|
| 34 |
+
def main():
    """Full training pipeline: config -> wandb -> data -> model -> optimizer -> train loop."""
    static_config = load_yaml()
    wandb_cfg = static_config['wandb_setup']

    # In a sweep, wandb merges the sweep's parameters over this static config.
    wandb.init(
        project=wandb_cfg.get('project','my_project'),
        group=wandb_cfg.get('experiment','default'),
        tags=wandb_cfg.get('tags',[]),
        job_type=wandb_cfg.get('job_type','train'),
        config=static_config,
    )

    cfg = wandb.config

    relative_save_dir = cfg['train']['save_dir']
    save_dir = (ROOT_DIR / relative_save_dir).resolve()
    os.makedirs(save_dir,exist_ok=True)

    best_acc = 0.0
    print(f" Save dir: {save_dir}")

    print(f" Model: {cfg['model']['type']}")
    print(f"Experiment Start! Mode: {'Sweep' if wandb.run.sweep_id else 'Manual'}")
    print(f" Lr: {cfg['optimizer']['lr']}, Batch: {cfg['data']['batch_size']}, Opt: {cfg['optimizer']['name']}")

    # Fix: 'seed' is nested under wandb_setup in config.yaml, so the previous
    # top-level cfg.get('seed', 42) always fell back to 42. Check the top
    # level first (for sweeps that set it), then the static wandb_setup block.
    seed = cfg.get('seed', wandb_cfg.get('seed', 42))
    seed_everthing(seed)
    device = get_device()

    relative_data_path = cfg['data']['data_path']

    absolute_data_path = (ROOT_DIR / relative_data_path).resolve()

    data_cfg = cfg['data'].copy()
    data_cfg['data_path'] = str(absolute_data_path)

    print(f'Loading data from {absolute_data_path}...')
    train_loader,test_loader = get_dataloader(data_cfg)

    # Sanity-check one batch's shape before building the model.
    dummy_x, dummy_y = next(iter(train_loader))
    print(f"🧐 Inspection - Input Shape: {dummy_x.shape}")

    model_type = cfg['model']['type']
    num_classes = cfg['model']['num_classes']

    dropout_rate = cfg['model'].get('dropout_rate',0.0)
    num_inputs = cfg['model'].get('num_inputs',3)

    input_size = cfg['model'].get('input_size',32)

    if model_type == 'SimpleCNN':
        model = SimpleCNN(
            num_inputs = num_inputs,
            input_size = input_size,
            num_classes = num_classes,
            dropout_rate = dropout_rate,
        )
    elif model_type == 'ResNet18':
        model = ResNet18_CIFAR(
            num_inputs = num_inputs,
            num_classes = num_classes,
            dropout_rate = dropout_rate,
        )
    elif model_type == 'TransferResNet50':
        model = TransferResNet50(
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.to(device)
    model = model.to(memory_format=torch.channels_last)

    # Compile only the backbone when the wrapper exposes one, so any wrapper
    # logic (e.g. train-time augmentation) stays eager.
    if hasattr(model,'net'):
        print("⚡ Compiling ResNet backbone...")
        model.net = torch.compile(model.net,mode='reduce-overhead')
    else:
        print("⚡ Compiling Full Model...")
        model = torch.compile(model,mode='reduce-overhead')

    opt_cfg = cfg['optimizer']
    opt_name = opt_cfg['name'].lower()

    # Discriminative learning rates (cast to float since YAML/sweeps may
    # deliver strings): the new head trains at `lr`, the pre-trained backbone
    # at `backbone_lr` (default: 1/10 of the head LR).
    lr_head = float(opt_cfg['lr'])
    lr_backbone = float(opt_cfg.get('backbone_lr', lr_head * 0.1))
    weight_decay = float(opt_cfg.get('weight_decay', 0.0))

    # Split parameters into backbone vs head by name ("fc" marks the classifier).
    backbone_params = []
    head_params = []

    for name, param in model.named_parameters():
        if "fc" in name:
            head_params.append(param)
        else:
            backbone_params.append(param)

    print(f"🔧 Optimizer Setup: Head LR={lr_head}, Backbone LR={lr_backbone}")

    # Build the optimizer with the two parameter groups.
    if opt_name == "adam":
        optimizer = torch.optim.Adam([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head}
        ], weight_decay=weight_decay)

    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head}
        ], weight_decay=weight_decay)

    elif opt_name == "sgd":
        optimizer = torch.optim.SGD([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head}
        ], momentum=0.9, weight_decay=weight_decay)

    else:
        raise ValueError(f"不支持的优化器: {opt_name}")

    scheduler = None
    if 'scheduler' in cfg and cfg['scheduler'].get('use_scheduler',False):
        sch_cfg = cfg['scheduler']

        if sch_cfg['type'] == 'CosineAnnealingLR':
            # T_max is tied to the configured epoch count; the scheduler
            # section's own T_max key is intentionally not read here.
            t_max = cfg['train']['epochs']
            eta_min = float(sch_cfg.get('eta_min',0.0))

            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max = t_max,
                eta_min = eta_min,
            )

        elif sch_cfg['type'] == 'StepLR':
            step_size = sch_cfg.get('step_size',10)
            gamma = sch_cfg.get('gamma',0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=step_size,
                gamma=gamma,
            )
    else:
        print('Not using Learning Rate Scheduler')

    # Label smoothing lightly regularizes the classifier.
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    epochs = cfg['train']['epochs']
    scaler = GradScaler('cuda')

    for epoch in range(epochs):
        train_epoch_loss,train_epoch_acc = train_one_epoch(epoch,model,train_loader,loss_fn,optimizer,device,scaler)
        val_epoch_loss,val_epoch_acc,bad_cases = evaluate(epoch,model,test_loader,loss_fn,device)

        # LR in effect for this epoch (param group 0 = backbone), captured
        # before the scheduler advances it.
        current_lr = optimizer.param_groups[0]['lr']

        if scheduler is not None:
            scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}\t[LR: {current_lr:>.6f}]\tTrain Loss: {train_epoch_loss:>.3f}\tTrain Acc: {train_epoch_acc:>.2%}\t|\tVal Loss: {val_epoch_loss:>.3f}\tVal Acc: {val_epoch_acc:>.2%}")

        # Checkpoint the best model by validation accuracy.
        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            save_name = f"{cfg['wandb_setup']['experiment']}_best.pth"
            save_path = save_dir / save_name

            # NOTE(review): model may be torch.compile-wrapped at this point,
            # so state_dict keys can carry an `_orig_mod.` prefix — confirm
            # when loading this checkpoint.
            torch.save(model.state_dict(),save_path)

            print(f"🌟 New Best Acc: {best_acc:.2f} -> Model save to: {save_path}")

        wandb.log({
            "train_epoch_loss":train_epoch_loss,
            "train_epoch_acc":train_epoch_acc,
            "test_epoch_loss":val_epoch_loss,
            "test_epoch_acc":val_epoch_acc,
            'best_acc':best_acc,
            "bad_cases":bad_cases,
            "learning_rate": current_lr,
            "epoch": epoch,
        })
    wandb.finish()
| 222 |
+
|
| 223 |
+
# Script entry point: run the full training pipeline.
if __name__ == '__main__':
    main()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|