YiMeng-SYSU committed
Commit e3469ed · verified · 1 Parent(s): 162c2c6

Initial commit of transfer learning project files

config.yaml ADDED
@@ -0,0 +1,42 @@
+ # W&B setup
+ wandb_setup:
+   project: "cifar10_transfer_learning"
+   experiment: "ResNet50_ImageNet_224px"
+   tags: ["cifar10", "resnet50"]
+   seed: 42
+   job_type: "train"
+
+ # Data loading parameters
+ data:
+   data_path: "./data"
+   batch_size: 64
+   num_workers: 4
+   image_size: [224, 224]
+   in_channels: 3
+
+ # Model architecture parameters
+ model:
+   type: "TransferResNet50"
+   dropout_rate: 0.0
+   num_classes: 10
+
+ # Training hyperparameters
+ train:
+   epochs: 30
+   save_dir: "./models"
+
+ # Optimizer and scheduler
+ optimizer:
+   name: "adamw"
+   lr: 0.001
+   backbone_lr: 0.00005
+   weight_decay: 1e-3
+
+ scheduler:
+   use_scheduler: True
+   type: "CosineAnnealingLR"
+   T_max: 30
+   eta_min: 1e-6
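One thing worth noting about this file (a quick illustration, assuming PyYAML's default YAML 1.1 resolver, which is what yaml.safe_load in train.py below uses): scientific-notation scalars written without a decimal point, such as 1e-3 and 1e-6 above, are loaded as strings rather than floats. This is why train.py wraps weight_decay and eta_min in float() before use.

    import yaml

    cfg = yaml.safe_load("weight_decay: 1e-3\nlr: 0.001")
    print(type(cfg["weight_decay"]))  # <class 'str'>  (no dot in the mantissa)
    print(type(cfg["lr"]))            # <class 'float'>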
models/ResNet50_ImageNet_224px_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca15171214c8a6d75fca585e6ce2b16ba4b6b735bbba58d92f729936a4b16a02
+ size 94445757
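This .pth file is stored as a Git LFS pointer. A minimal loading sketch, not part of the commit, assuming the checkpoint was written by train.py below, where model.net is wrapped by torch.compile, so the saved keys carry a `net._orig_mod.` prefix that must be stripped before loading into an uncompiled model:

    import torch
    from src.model import TransferResNet50

    state = torch.load("models/ResNet50_ImageNet_224px_best.pth", map_location="cpu")
    # torch.compile wraps the module, prefixing its parameters with `_orig_mod.`
    state = {k.replace("net._orig_mod.", "net."): v for k, v in state.items()}

    model = TransferResNet50(num_classes=10)
    model.load_state_dict(state)
    model.eval()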
src/__init__.py ADDED
File without changes
src/dataset.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from src.path import DATA_DIR
+
+ def get_dataloader(config):
+     batch_size = config.get('batch_size', 64)
+     data_path = config.get('data_path', DATA_DIR)
+     num_workers = config.get('num_workers', 4)
+
+     # ImageNet statistics, matching the pre-trained ResNet50 backbone
+     mean = [0.485, 0.456, 0.406]
+     std = [0.229, 0.224, 0.225]
+
+     train_transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std),
+     ])
+     val_transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std),
+     ])
+
+     train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
+     test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=val_transform)
+
+     train_loader = DataLoader(
+         train_data,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=True,
+         pin_memory=True,
+         persistent_workers=num_workers > 0,  # persistent_workers requires num_workers > 0
+     )
+
+     test_loader = DataLoader(
+         test_data,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=False,
+         pin_memory=True,
+         persistent_workers=num_workers > 0,
+     )
+
+     return train_loader, test_loader
src/engine.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import wandb
+ from torch.amp import autocast
+
+ def train_one_epoch(epoch_id, model, data_loader, loss_fn, optimizer, device, scaler):
+     model.train()
+     training_loss = 0.0
+     running_correct = 0
+     total_samples = 0
+
+     for batch, (X, y) in enumerate(data_loader):
+         if not X.is_cuda:
+             X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
+         X = X.to(memory_format=torch.channels_last)
+         optimizer.zero_grad(set_to_none=True)
+
+         # Mixed-precision forward pass; the GradScaler handles loss scaling
+         with autocast('cuda', dtype=torch.float16):
+             pred = model(X)
+             loss = loss_fn(pred, y)
+
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()
+
+         pred_ids = pred.argmax(1)
+         running_correct += (pred_ids == y).type(torch.int).sum().item()
+         total_samples += y.size(0)
+
+         training_loss += loss.item()
+
+     train_epoch_loss = training_loss / len(data_loader)
+     train_epoch_acc = running_correct / total_samples
+
+     return train_epoch_loss, train_epoch_acc
+
+ def evaluate(epoch_id, model, data_loader, loss_fn, device):
+     model.eval()
+     testing_loss = 0.0
+     testing_correct = 0
+     total_samples = 0
+     bad_cases = []
+
+     # ImageNet statistics used to un-normalize images for logging;
+     # these must match the Normalize() values in src/dataset.py
+     mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+     std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+
+     with torch.no_grad():
+         for X, y in data_loader:
+             if not X.is_cuda:
+                 X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
+
+             pred = model(X)
+             loss = loss_fn(pred, y)
+             testing_loss += loss.item()
+
+             pred_ids = pred.argmax(1)
+             testing_correct += (pred_ids == y).type(torch.int).sum().item()
+             total_samples += y.size(0)
+
+             # Collect up to 20 misclassified images for the W&B dashboard
+             if len(bad_cases) < 20:
+                 wrong_idx = (pred_ids != y).nonzero()
+                 for idx in wrong_idx:
+                     if len(bad_cases) < 20:
+                         raw_img = X[idx.item()].cpu()
+
+                         img = raw_img * std + mean
+                         img = torch.clamp(img, 0, 1)
+
+                         bad_cases.append(
+                             wandb.Image(img, caption=f"Pred: {pred_ids[idx].item()} | True: {y[idx].item()}")
+                         )
+
+     val_epoch_loss = testing_loss / len(data_loader)
+     val_epoch_acc = testing_correct / total_samples
+
+     return val_epoch_loss, val_epoch_acc, bad_cases
src/model.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ from torch import nn
+ from torchvision import transforms
+ from torchvision.models import resnet18, resnet50, ResNet50_Weights
+
+ class SimpleCNN(nn.Module):
+     def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
+         super().__init__()
+
+         self.features = nn.Sequential(
+             # Block 1
+             nn.Conv2d(num_inputs, 32, kernel_size=3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+
+             # Block 2
+             nn.Conv2d(32, 64, kernel_size=3, padding=1),
+             # [Fix 1] This was previously 62; it must be 64 to match the preceding layer's output channels
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.MaxPool2d(2),
+         )
+         final_size = input_size // 4
+         flatten_dim = 64 * final_size * final_size
+
+         self.classifier = nn.Sequential(
+             nn.Flatten(),
+             # Shape arithmetic: 28 -> 14 -> 7, with 64 channels
+             nn.Linear(flatten_dim, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(512, num_classes),
+         )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         # [Fix 2] This was previously misspelled nn.Linaer, so the fully connected layers were never initialized
+         if isinstance(m, (nn.Conv2d, nn.Linear)):
+             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+         elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+             nn.init.constant_(m.weight, 1)
+             # [Fix 3] This was previously misspelled m.weiht, so the bias was never zeroed
+             nn.init.constant_(m.bias, 0)
+
+     # [Fix 4] The forward method was missing entirely; without it the model
+     # cannot run data through the network at all
+     def forward(self, x):
+         x = self.features(x)
+         x = self.classifier(x)
+         return x
+
+ class ResNet18_CIFAR(nn.Module):
+     def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
+         super().__init__()
+
+         # Tensor-based augmentation, applied only in training mode (see forward)
+         self.aug = nn.Sequential(
+             transforms.RandomHorizontalFlip(),
+             transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
+         )
+
+         self.net = resnet18(weights=None)
+
+         # Adapt the stem to 32x32 CIFAR inputs: 3x3 conv, no max-pool
+         self.net.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.net.maxpool = nn.Identity()
+
+         self.net.fc = nn.Linear(512, num_classes)
+
+     def forward(self, x):
+         if self.training:
+             x = self.aug(x)
+         return self.net(x)
+
+ class TransferResNet50(nn.Module):
+     def __init__(self, num_classes=10, dropout_rate=0.0):
+         super().__init__()
+
+         print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
+         # 1. Load the pre-trained weights
+         self.net = resnet50(weights=ResNet50_Weights.DEFAULT)
+
+         # 2. Fine-tune the whole network (no freezing).
+         # CIFAR-10 differs substantially from ImageNet (resolution, object categories),
+         # so fine-tuning the backbone is necessary. train.py already protects it with
+         # a very small backbone_lr (5e-5 in config.yaml), so nothing is frozen here.
+
+         # 3. Replace the classification head
+         num_ftrs = self.net.fc.in_features
+         self.net.fc = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(num_ftrs, num_classes),
+         )
+
+     def forward(self, x):
+         return self.net(x)
src/path.py ADDED
@@ -0,0 +1,23 @@
+ from pathlib import Path
+ import sys
+
+ def get_project_root() -> Path:
+     """Return the absolute path of the project root directory."""
+     # Check whether we are running as a frozen (packaged) executable
+     if getattr(sys, 'frozen', False):
+         # Packaged executable: the root is the directory containing the binary
+         return Path(sys.executable).parent
+     else:
+         # Development environment: the root is the parent of the src directory
+         current_file = Path(__file__).resolve()
+         return current_file.parent.parent
+
+ PROJECT_ROOT = get_project_root()
+ CONFIG_PATH = PROJECT_ROOT / 'config.yaml'
+ DATA_DIR = PROJECT_ROOT / 'data'
+ MODELS_DIR = PROJECT_ROOT / 'models'
+
+ for directory in [DATA_DIR, MODELS_DIR]:
+     directory.mkdir(exist_ok=True)
src/utils.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import random
+ import numpy as np
+ import os
+
+ def get_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     else:
+         return "cpu"
+
+ def seed_everything(seed=42):
+     random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+
+     # benchmark=True trades exact reproducibility for speed: cuDNN is free to
+     # pick the fastest kernels, which may be non-deterministic
+     torch.backends.cudnn.deterministic = False
+     torch.backends.cudnn.benchmark = True
start_sweep.sh ADDED
@@ -0,0 +1,46 @@
+ #!/bin/bash
+
+ # === Edit these ===
+ # 1. Your Sweep ID (taken from the output of `wandb sweep sweep.yaml`)
+ SWEEP_ID="1217820711-sun-yat-sen-university/cifar10_chanllenge/srzfvp0g"
+
+ # 2. Number of agents to run in parallel
+ #    (on a 9800X3D + 5070 Ti, 3 is a reasonable choice)
+ NUM_AGENTS=2
+
+ # 3. Tmux session name (any name works)
+ SESSION_NAME="sweep_resnet18_try1"
+
+ # 4. Your Conda environment name
+ CONDA_ENV="deep_learning"
+ # ==================
+
+ # If a session with this name already exists, kill it first (avoids errors)
+ tmux has-session -t $SESSION_NAME 2>/dev/null
+ if [ $? == 0 ]; then
+     echo "⚠️ Session $SESSION_NAME already exists. Killing it..."
+     tmux kill-session -t $SESSION_NAME
+ fi
+
+ # Create a new session (detached)
+ tmux new-session -d -s $SESSION_NAME
+
+ # Create panes in a loop and start one agent per pane
+ for ((i=1; i<=NUM_AGENTS; i++)); do
+     # Split the screen for every agent after the first
+     if [ $i -gt 1 ]; then
+         tmux split-window -t $SESSION_NAME
+         tmux select-layout -t $SESSION_NAME tiled
+     fi
+
+     # Send the commands: activate the environment, then run the agent
+     # (C-m is the Enter key)
+     tmux send-keys -t $SESSION_NAME "conda activate $CONDA_ENV" C-m
+     tmux send-keys -t $SESSION_NAME "wandb agent $SWEEP_ID" C-m
+
+     echo "🚀 Agent $i started..."
+ done
+
+ # Attach to the tmux session
+ echo "All agents running! Attaching..."
+ tmux attach -t $SESSION_NAME
sweep.yaml ADDED
@@ -0,0 +1,82 @@
+ program: train.py
+ # Note: Hyperband must be paired with random search, not bayes,
+ # because Hyperband relies on random sampling to cover the search space
+ # and on pruning for efficiency
+ method: random
+
+ project: "cifar10_chanllenge"
+ name: "20251210-Hyperband-AdamW-LrSearch"
+ description: >
+   This experiment evaluates ResNet18 on CIFAR-10 using AdamW
+   with strong regularization (weight decay > 0.01).
+   Full-VRAM data loading is enabled.
+ run_cap: 100
+ command:
+   - ${env}
+   - ${interpreter}
+   - ${program}
+   - ${args}
+
+ metric:
+   name: test_epoch_acc
+   goal: maximize
+
+ # 🔥 Core piece: the Hyperband early-termination policy
+ early_terminate:
+   type: hyperband
+   # Minimum iterations: a run is only considered for termination after 10 epochs,
+   # so a model is not killed before it has warmed up
+   min_iter: 10
+   # Elimination ratio: each bracket discards the bottom 2/3 of runs and keeps 1/3
+   eta: 3
+
+ parameters:
+   project_name:
+     value: "cifar10_hyperband_search"
+
+   # --- Training epochs ---
+   train:
+     parameters:
+       epochs:
+         # Maximum epoch count; Hyperband truncates runs along the way.
+         # 150 gives the "survivors" enough budget to converge fully
+         value: 150
+
+   # --- Data parameters ---
+   data:
+     parameters:
+       batch_size:
+         # Search range covering the small batches SGD prefers and the large batches AdamW prefers
+         values: [256, 512, 1024]
+
+   # --- Model parameters ---
+   model:
+     parameters:
+       type:
+         value: "ResNet18"
+       num_classes:
+         value: 10
+       dropout_rate:
+         # ResNet has BatchNorm built in and usually needs little dropout, so search a small range
+         distribution: uniform
+         min: 0.0
+         max: 0.2
+
+   # --- Optimizer (the focus of the search) ---
+   optimizer:
+     parameters:
+       name:
+         # Try both SGD (the classic SOTA choice) and AdamW (the modern default)
+         values: ['sgd', 'adamw']
+
+       lr:
+         # The learning-rate span must be wide: SGD wants ~0.1 while AdamW wants ~0.001.
+         # log_uniform_values samples uniformly on a log scale, covering both ends
+         distribution: log_uniform_values
+         min: 0.0001
+         max: 0.2
+
+       weight_decay:
+         # Search over the regularization strength
+         distribution: log_uniform_values
+         min: 1e-4  # 0.0001 (suits SGD)
+         max: 1e-1  # 0.1 (suits AdamW)
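For reference, this sweep definition is driven by the standard wandb CLI. A usage sketch, assuming a logged-in wandb account (the entity/project path is the one hard-coded into start_sweep.sh above):

    wandb sweep sweep.yaml                        # registers the sweep and prints its ID
    wandb agent <entity>/<project>/<sweep_id>     # or run start_sweep.sh to launch several agents in tmux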
train.py ADDED
@@ -0,0 +1,229 @@
+ import torch
+ torch.set_float32_matmul_precision('high')
+ import os
+ import yaml
+ import wandb
+ from torch import nn
+ from pathlib import Path
+ import sys
+ from torch.amp import GradScaler
+
+ os.environ["CXX"] = "/usr/bin/g++"
+ os.environ["CC"] = "/usr/bin/gcc"
+
+ ROOT_DIR = Path(__file__).resolve().parent
+ if str(ROOT_DIR) not in sys.path:
+     sys.path.append(str(ROOT_DIR))
+
+ from src.dataset import get_dataloader
+ from src.utils import get_device, seed_everything
+ from src.model import ResNet18_CIFAR, SimpleCNN, TransferResNet50
+ from src.engine import train_one_epoch, evaluate
+
+ def load_yaml(config_path=None):
+     if config_path is None:
+         config_path = ROOT_DIR / 'config.yaml'
+     try:
+         with open(config_path, 'r', encoding='utf-8') as f:
+             config = yaml.safe_load(f)
+         return config
+     except FileNotFoundError:
+         print(f"{config_path} File not found!!")
+         exit(1)
+
+ def main():
+     static_config = load_yaml()
+     wandb_cfg = static_config['wandb_setup']
+
+     wandb.init(
+         project=wandb_cfg.get('project', 'my_project'),
+         group=wandb_cfg.get('experiment', 'default'),
+         tags=wandb_cfg.get('tags', []),
+         job_type=wandb_cfg.get('job_type', 'train'),
+         config=static_config,
+     )
+
+     cfg = wandb.config
+
+     relative_save_dir = cfg['train']['save_dir']
+     save_dir = (ROOT_DIR / relative_save_dir).resolve()
+     os.makedirs(save_dir, exist_ok=True)
+
+     best_acc = 0.0
+     print(f" Save dir: {save_dir}")
+     print(f" Model: {cfg['model']['type']}")
+     print(f"Experiment Start! Mode: {'Sweep' if wandb.run.sweep_id else 'Manual'}")
+     print(f" Lr: {cfg['optimizer']['lr']}, Batch: {cfg['data']['batch_size']}, Opt: {cfg['optimizer']['name']}")
+
+     # The seed is nested under wandb_setup in config.yaml
+     seed_everything(cfg['wandb_setup'].get('seed', 42))
+     device = get_device()
+
+     relative_data_path = cfg['data']['data_path']
+     absolute_data_path = (ROOT_DIR / relative_data_path).resolve()
+
+     data_cfg = cfg['data'].copy()
+     data_cfg['data_path'] = str(absolute_data_path)
+
+     print(f'Loading data from {absolute_data_path}...')
+     train_loader, test_loader = get_dataloader(data_cfg)
+
+     # 🔍 Sanity check: inspect the shape of one batch
+     dummy_x, dummy_y = next(iter(train_loader))
+     print(f"🧐 Inspection - Input Shape: {dummy_x.shape}")
+
+     model_type = cfg['model']['type']
+     num_classes = cfg['model']['num_classes']
+     dropout_rate = cfg['model'].get('dropout_rate', 0.0)
+     num_inputs = cfg['model'].get('num_inputs', 3)
+     input_size = cfg['model'].get('input_size', 32)
+
+     if model_type == 'SimpleCNN':
+         model = SimpleCNN(
+             num_inputs=num_inputs,
+             input_size=input_size,
+             num_classes=num_classes,
+             dropout_rate=dropout_rate,
+         )
+     elif model_type == 'ResNet18':
+         model = ResNet18_CIFAR(
+             num_inputs=num_inputs,
+             num_classes=num_classes,
+             dropout_rate=dropout_rate,
+         )
+     elif model_type == 'TransferResNet50':
+         model = TransferResNet50(
+             num_classes=num_classes,
+             dropout_rate=dropout_rate,
+         )
+     else:
+         raise ValueError(f"Unknown model type: {model_type}")
+
+     model.to(device)
+     model = model.to(memory_format=torch.channels_last)
+
+     if hasattr(model, 'net'):
+         print("⚡ Compiling ResNet backbone...")
+         model.net = torch.compile(model.net, mode='reduce-overhead')
+     else:
+         print("⚡ Compiling Full Model...")
+         model = torch.compile(model, mode='reduce-overhead')
+
+     opt_cfg = cfg['optimizer']
+     opt_name = opt_cfg['name'].lower()
+
+     # 1. Read both learning rates from the config (cast to float: YAML
+     #    scientific notation such as 5e-5 may load as a string)
+     lr_head = float(opt_cfg['lr'])  # `lr` in config.yaml
+     lr_backbone = float(opt_cfg.get('backbone_lr', lr_head * 0.1))  # `backbone_lr`; defaults to 1/10 of the head LR
+     weight_decay = float(opt_cfg.get('weight_decay', 0.0))
+
+     # 2. Split the model parameters into two groups (backbone vs. head).
+     #    Heuristic: the final ResNet layer is conventionally named "fc"
+     backbone_params = []
+     head_params = []
+
+     for name, param in model.named_parameters():
+         if "fc" in name:
+             head_params.append(param)
+         else:
+             backbone_params.append(param)
+
+     print(f"🔧 Optimizer Setup: Head LR={lr_head}, Backbone LR={lr_backbone}")
+
+     # 3. Build the optimizer from the two parameter groups
+     if opt_name == "adam":
+         optimizer = torch.optim.Adam([
+             {'params': backbone_params, 'lr': lr_backbone},
+             {'params': head_params, 'lr': lr_head}
+         ], weight_decay=weight_decay)
+     elif opt_name == "adamw":
+         optimizer = torch.optim.AdamW([
+             {'params': backbone_params, 'lr': lr_backbone},
+             {'params': head_params, 'lr': lr_head}
+         ], weight_decay=weight_decay)
+     elif opt_name == "sgd":
+         optimizer = torch.optim.SGD([
+             {'params': backbone_params, 'lr': lr_backbone},
+             {'params': head_params, 'lr': lr_head}
+         ], momentum=0.9, weight_decay=weight_decay)
+     else:
+         raise ValueError(f"Unsupported optimizer: {opt_name}")
+
+     scheduler = None
+     if 'scheduler' in cfg and cfg['scheduler'].get('use_scheduler', False):
+         sch_cfg = cfg['scheduler']
+
+         if sch_cfg['type'] == 'CosineAnnealingLR':
+             t_max = cfg['train']['epochs']
+             eta_min = float(sch_cfg.get('eta_min', 0.0))
+             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer,
+                 T_max=t_max,
+                 eta_min=eta_min,
+             )
+         elif sch_cfg['type'] == 'StepLR':
+             step_size = sch_cfg.get('step_size', 10)
+             gamma = sch_cfg.get('gamma', 0.1)
+             scheduler = torch.optim.lr_scheduler.StepLR(
+                 optimizer,
+                 step_size=step_size,
+                 gamma=gamma,
+             )
+     else:
+         print('Not using Learning Rate Scheduler')
+
+     loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
+
+     epochs = cfg['train']['epochs']
+     scaler = GradScaler('cuda')
+
+     for epoch in range(epochs):
+         train_epoch_loss, train_epoch_acc = train_one_epoch(epoch, model, train_loader, loss_fn, optimizer, device, scaler)
+         val_epoch_loss, val_epoch_acc, bad_cases = evaluate(epoch, model, test_loader, loss_fn, device)
+
+         current_lr = optimizer.param_groups[0]['lr']
+
+         if scheduler is not None:
+             scheduler.step()
+
+         print(f"Epoch {epoch+1}/{epochs}\t[LR: {current_lr:>.6f}]\tTrain Loss: {train_epoch_loss:>.3f}\tTrain Acc: {train_epoch_acc:>.2%}\t|\tVal Loss: {val_epoch_loss:>.3f}\tVal Acc: {val_epoch_acc:>.2%}")
+
+         if val_epoch_acc > best_acc:
+             best_acc = val_epoch_acc
+             save_name = f"{cfg['wandb_setup']['experiment']}_best.pth"
+             save_path = save_dir / save_name
+
+             torch.save(model.state_dict(), save_path)
+             print(f"🌟 New Best Acc: {best_acc:.2%} -> Model saved to: {save_path}")
+
+         wandb.log({
+             "train_epoch_loss": train_epoch_loss,
+             "train_epoch_acc": train_epoch_acc,
+             "test_epoch_loss": val_epoch_loss,
+             "test_epoch_acc": val_epoch_acc,
+             'best_acc': best_acc,
+             "bad_cases": bad_cases,
+             "learning_rate": current_lr,
+             "epoch": epoch,
+         })
+     wandb.finish()
+
+ if __name__ == '__main__':
+     main()
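With the files above in place, training runs in either of the two modes that the "Mode:" print distinguishes. A usage sketch, assuming a CUDA GPU (GradScaler('cuda') and the autocast('cuda') context in src/engine.py are CUDA-specific) and a logged-in wandb account:

    # manual single run, configured entirely by config.yaml
    python train.py

    # sweep mode: the agent launches train.py repeatedly and merges each
    # sampled parameter set into wandb.config on top of config.yaml
    wandb agent <entity>/<project>/<sweep_id>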