import torch
torch.set_float32_matmul_precision('high')
import os
import yaml
import wandb
from torch import nn
from pathlib import Path
import sys
from torch.amp import GradScaler

os.environ["CXX"] = "/usr/bin/g++"
os.environ["CC"] = "/usr/bin/gcc"

ROOT_DIR = Path(__file__).resolve().parent
if str(ROOT_DIR) not in sys.path:  # sys.path holds strings, so compare as str
    sys.path.append(str(ROOT_DIR))

from src.dataset import get_dataloader
from src.utils import get_device, seed_everthing
from src.model import ResNet18_CIFAR, SimpleCNN, TransferResNet50
from src.engine import train_one_epoch, evaluate


def load_yaml(config_path=None):
    if config_path is None:
        config_path = ROOT_DIR / 'config.yaml'
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        return config
    except FileNotFoundError:
        print(f"{config_path} file not found!")
        sys.exit(1)


def main():
    static_config = load_yaml()
    wandb_cfg = static_config['wandb_setup']

    wandb.init(
        project=wandb_cfg.get('project', 'my_project'),
        group=wandb_cfg.get('experiment', 'default'),
        tags=wandb_cfg.get('tags', []),
        job_type=wandb_cfg.get('job_type', 'train'),
        config=static_config,
    )
    cfg = wandb.config

    relative_save_dir = cfg['train']['save_dir']
    save_dir = (ROOT_DIR / relative_save_dir).resolve()
    os.makedirs(save_dir, exist_ok=True)
    best_acc = 0.0

    print(f" Save dir: {save_dir}")
    print(f" Model: {cfg['model']['type']}")
    print(f"Experiment start! Mode: {'Sweep' if wandb.run.sweep_id else 'Manual'}")
    print(f" Lr: {cfg['optimizer']['lr']}, Batch: {cfg['data']['batch_size']}, Opt: {cfg['optimizer']['name']}")

    seed_everthing(cfg.get('seed', 42))
    device = get_device()

    relative_data_path = cfg['data']['data_path']
    absolute_data_path = (ROOT_DIR / relative_data_path).resolve()
    data_cfg = cfg['data'].copy()
    data_cfg['data_path'] = str(absolute_data_path)

    print(f'Loading data from {absolute_data_path}...')
    train_loader, test_loader = get_dataloader(data_cfg)

    # 🔍 Sanity check: inspect the shape of one batch
    dummy_x, dummy_y = next(iter(train_loader))
    print(f"🧐 Inspection - Input Shape: {dummy_x.shape}")

    model_type = cfg['model']['type']
    num_classes = cfg['model']['num_classes']
    dropout_rate = cfg['model'].get('dropout_rate', 0.0)
    num_inputs = cfg['model'].get('num_inputs', 3)
    input_size = cfg['model'].get('input_size', 32)

    if model_type == 'SimpleCNN':
        model = SimpleCNN(
            num_inputs=num_inputs,
            input_size=input_size,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    elif model_type == 'ResNet18':
        model = ResNet18_CIFAR(
            num_inputs=num_inputs,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    elif model_type == 'TransferResNet50':
        model = TransferResNet50(
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.to(device)
    model = model.to(memory_format=torch.channels_last)

    if hasattr(model, 'net'):
        print("⚡ Compiling ResNet backbone...")
        model.net = torch.compile(model.net, mode='reduce-overhead')
    else:
        print("⚡ Compiling full model...")
        model = torch.compile(model, mode='reduce-overhead')

    opt_cfg = cfg['optimizer']
    opt_name = opt_cfg['name'].lower()

    # 1. Read the two learning rates from config (be sure to cast to float)
    lr_head = float(opt_cfg['lr'])                                  # `lr` in config
    lr_backbone = float(opt_cfg.get('backbone_lr', lr_head * 0.1))  # `backbone_lr` in config; defaults to 1/10 of the head LR
    weight_decay = float(opt_cfg.get('weight_decay', 0.0))
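    # NOTE (assumption): with a standard torchvision ResNet the classifier
    # parameters are named "fc.weight" / "fc.bias" (and may appear as
    # "net.fc.*" or "_orig_mod.fc.*" once the model is wrapped or compiled),
    # so the substring check below still routes them to the head group.
    # If the custom heads in src.model use a different attribute name,
    # adjust the match accordingly.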
    # 2. Split the model parameters into two groups (backbone vs head).
    #    Heuristic: a parameter whose name contains "fc" belongs to the head
    #    (the final layer of a ResNet is usually named `fc`).
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if "fc" in name:
            head_params.append(param)
        else:
            backbone_params.append(param)

    print(f"🔧 Optimizer setup: Head LR={lr_head}, Backbone LR={lr_backbone}")

    # 3. Initialize the optimizer with the two parameter groups
    if opt_name == "adam":
        optimizer = torch.optim.Adam([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head},
        ], weight_decay=weight_decay)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head},
        ], weight_decay=weight_decay)
    elif opt_name == "sgd":
        optimizer = torch.optim.SGD([
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head},
        ], momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {opt_name}")

    scheduler = None
    if 'scheduler' in cfg and cfg['scheduler'].get('use_scheduler', False):
        sch_cfg = cfg['scheduler']
        if sch_cfg['type'] == 'CosineAnnealingLR':
            t_max = cfg['train']['epochs']
            eta_min = float(sch_cfg.get('eta_min', 0.0))
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=t_max,
                eta_min=eta_min,
            )
        elif sch_cfg['type'] == 'StepLR':
            step_size = sch_cfg.get('step_size', 10)
            gamma = sch_cfg.get('gamma', 0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=step_size,
                gamma=gamma,
            )
    else:
        print('Not using a learning-rate scheduler')

    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    epochs = cfg['train']['epochs']
    scaler = GradScaler('cuda')

    for epoch in range(epochs):
        train_epoch_loss, train_epoch_acc = train_one_epoch(epoch, model, train_loader, loss_fn, optimizer, device, scaler)
        val_epoch_loss, val_epoch_acc, bad_cases = evaluate(epoch, model, test_loader, loss_fn, device)

        current_lr = optimizer.param_groups[0]['lr']
        if scheduler is not None:
            scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}\t[LR: {current_lr:>.6f}]\t"
              f"Train Loss: {train_epoch_loss:>.3f}\tTrain Acc: {train_epoch_acc:>.2%}\t|\t"
              f"Val Loss: {val_epoch_loss:>.3f}\tVal Acc: {val_epoch_acc:>.2%}")

        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            save_name = f"{cfg['wandb_setup']['experiment']}_best.pth"
            save_path = save_dir / save_name
            # NOTE: since the model (or its .net) is compiled, the saved keys may carry an `_orig_mod.` prefix
            torch.save(model.state_dict(), save_path)
            print(f"🌟 New best acc: {best_acc:.2%} -> Model saved to: {save_path}")

        wandb.log({
            "train_epoch_loss": train_epoch_loss,
            "train_epoch_acc": train_epoch_acc,
            "test_epoch_loss": val_epoch_loss,
            "test_epoch_acc": val_epoch_acc,
            "best_acc": best_acc,
            "bad_cases": bad_cases,
            "learning_rate": current_lr,
            "epoch": epoch,
        })

    wandb.finish()


if __name__ == '__main__':
    main()
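
# ----------------------------------------------------------------------------
# For reference, a minimal sketch of the config.yaml layout this script reads.
# The key names come from the lookups above; the values are illustrative
# assumptions, not the project's actual defaults.
#
# wandb_setup:
#   project: my_project
#   experiment: resnet50_transfer
#   tags: [cifar10, transfer]
#   job_type: train
# seed: 42
# data:
#   data_path: data/cifar10
#   batch_size: 128
# model:
#   type: TransferResNet50        # SimpleCNN | ResNet18 | TransferResNet50
#   num_classes: 10
#   dropout_rate: 0.2
#   num_inputs: 3
#   input_size: 32
# optimizer:
#   name: adamw                   # adam | adamw | sgd
#   lr: 1.0e-3                    # head learning rate
#   backbone_lr: 1.0e-4           # optional; defaults to lr * 0.1
#   weight_decay: 0.05
# scheduler:
#   use_scheduler: true
#   type: CosineAnnealingLR       # or StepLR (with step_size, gamma)
#   eta_min: 1.0e-6
# train:
#   epochs: 30
#   save_dir: checkpoints
# ----------------------------------------------------------------------------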