import os
import sys
from pathlib import Path

import torch
import wandb
import yaml
from torch import nn
from torch.amp import GradScaler

# Allow TF32 matmul kernels on Ampere+ GPUs: faster training at negligible accuracy cost.
torch.set_float32_matmul_precision('high')

# Point torch.compile's C++ toolchain at the system compilers.
os.environ["CXX"] = "/usr/bin/g++"
os.environ["CC"] = "/usr/bin/gcc"

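# Make the project root importable so the `src.*` modules resolve
# regardless of the directory the script is launched from.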
ROOT_DIR = Path(__file__).resolve().parent
if str(ROOT_DIR) not in sys.path:  # sys.path holds strings, so compare against str(ROOT_DIR)
    sys.path.append(str(ROOT_DIR))

from src.dataset import get_dataloader
from src.utils import get_device, seed_everthing
from src.model import TimmModel
from src.engine import train_one_epoch, evaluate


def load_yaml(config_path=None):
    """Load the experiment config, defaulting to config.yaml beside this script."""
    if config_path is None:
        config_path = ROOT_DIR / 'config.yaml'
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        return config
    except FileNotFoundError:
        print(f"Config file not found: {config_path}")
        sys.exit(1)


def main():
    static_config = load_yaml()
    wandb_cfg = static_config['wandb_setup']

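    # Register the run with W&B; when launched by a sweep agent, the agent's
    # parameters override the values passed in `config`.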
    wandb.init(
        project=wandb_cfg.get('project', 'my_project'),
        group=wandb_cfg.get('experiment', 'default'),
        tags=wandb_cfg.get('tags', []),
        job_type=wandb_cfg.get('job_type', 'train'),
        config=static_config,
    )

    # Read hyperparameters back through wandb.config so sweep overrides take effect.
    cfg = wandb.config

    relative_save_dir = cfg['train']['save_dir']
    save_dir = (ROOT_DIR / relative_save_dir).resolve()
    os.makedirs(save_dir, exist_ok=True)

    best_acc = 0.0
    print(f"Save dir: {save_dir}")
    print(f"Model: {cfg['model']['type']}")
    print(f"Experiment Start! Mode: {'Sweep' if wandb.run.sweep_id else 'Manual'}")
    print(f"Head LR: {cfg['optimizer']['lr']}, Backbone LR: {cfg['optimizer']['backbone_lr']}, Batch: {cfg['data']['batch_size']}, Opt: {cfg['optimizer']['name']}")

    seed_everthing(cfg.get('seed', 42))
    device = get_device()

    # Resolve the data path against the project root so the config can stay relative.
    relative_data_path = cfg['data']['data_path']
    absolute_data_path = (ROOT_DIR / relative_data_path).resolve()

    data_cfg = cfg['data'].copy()
    data_cfg['data_path'] = str(absolute_data_path)

    print(f'Loading data from {absolute_data_path}...')
    train_loader, test_loader = get_dataloader(data_cfg)

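    # Pull one batch up front so shape problems surface before training starts.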
    dummy_x, dummy_y = next(iter(train_loader))
    print(f"🧐 Inspection - Input Shape: {dummy_x.shape}")

    model_type = cfg['model']['type']
    num_classes = cfg['model']['num_classes']
    dropout_rate = cfg['model'].get('dropout_rate', 0.0)
    num_inputs = cfg['model'].get('num_inputs', 3)    # read for completeness; not used below
    input_size = cfg['model'].get('input_size', 32)   # read for completeness; not used below

    if model_type == "TimmModel":
        model = TimmModel(
            model_name=cfg['model']['model_name'],
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # channels_last memory format speeds up convolutions under mixed precision on CUDA.
    model.to(device)
    model = model.to(memory_format=torch.channels_last)

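    # Compile only the timm backbone when the wrapper exposes it as `model.net`,
    # leaving the classifier head eager; otherwise compile the whole model.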
    if hasattr(model, 'net'):
        print(f"⚡ Compiling {cfg['model']['model_name']} backbone...")
        model.net = torch.compile(model.net, mode='reduce-overhead')
    else:
        print("⚡ Compiling Full Model...")
        model = torch.compile(model, mode='reduce-overhead')

    opt_cfg = cfg['optimizer']
    opt_name = opt_cfg['name'].lower()

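    # Discriminative learning rates: the pretrained backbone defaults to
    # one tenth of the freshly initialised head's rate.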
    lr_head = float(opt_cfg['lr'])
    lr_backbone = float(opt_cfg.get('backbone_lr', lr_head * 0.1))
    weight_decay = float(opt_cfg.get('weight_decay', 0.0))

    # Split parameters by name: timm classifiers are conventionally named
    # "head" or "fc"; everything else belongs to the backbone.
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if "head" in name or "fc" in name:
            head_params.append(param)
        else:
            backbone_params.append(param)

    print(f"🔧 Optimizer: Backbone params: {len(backbone_params)}, Head params: {len(head_params)}")

    param_groups = [
        {'params': backbone_params, 'lr': lr_backbone},
        {'params': head_params, 'lr': lr_head},
    ]
    if opt_name == "adam":
        optimizer = torch.optim.Adam(param_groups, weight_decay=weight_decay)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
    elif opt_name == "sgd":
        optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {opt_name}")

    scheduler = None
    if 'scheduler' in cfg and cfg['scheduler'].get('use_scheduler', False):
        sch_cfg = cfg['scheduler']
        if sch_cfg['type'] == 'CosineAnnealingLR':
            # One cosine cycle spanning the whole run.
            t_max = cfg['train']['epochs']
            eta_min = float(sch_cfg.get('eta_min', 0.0))
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=t_max,
                eta_min=eta_min,
            )
        elif sch_cfg['type'] == 'StepLR':
            step_size = sch_cfg.get('step_size', 10)
            gamma = sch_cfg.get('gamma', 0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=step_size,
                gamma=gamma,
            )
    else:
        print('Not using a learning rate scheduler')

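    # Cross-entropy with label smoothing regularises over-confident predictions.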
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    epochs = cfg['train']['epochs']
    # GradScaler handles the loss-scaling side of mixed-precision training.
    scaler = GradScaler('cuda')

    for epoch in range(epochs):
        train_epoch_loss, train_epoch_acc = train_one_epoch(
            epoch, model, train_loader, loss_fn, optimizer, device, scaler)
        val_epoch_loss, val_epoch_acc, bad_cases = evaluate(
            epoch, model, test_loader, loss_fn, device)

        # Capture the learning rates used this epoch before the scheduler updates them.
        lr_backbone = optimizer.param_groups[0]['lr']
        lr_head = optimizer.param_groups[1]['lr']
        if scheduler is not None:
            scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}\t"
              f"[Head LR: {lr_head:.6f} | Backbone LR: {lr_backbone:.8f}]\t"
              f"Train Loss: {train_epoch_loss:.3f}\tTrain Acc: {train_epoch_acc:.2%}\t|\t"
              f"Val Loss: {val_epoch_loss:.3f}\tVal Acc: {val_epoch_acc:.2%}")

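        # Keep only the best checkpoint, named after the experiment.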
        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            save_name = f"{cfg['wandb_setup']['experiment']}_best.pth"
            save_path = save_dir / save_name
            # torch.compile wraps modules, so saved keys carry an `_orig_mod.`
            # prefix; account for it when reloading the weights.
            torch.save(model.state_dict(), save_path)
            print(f"🌟 New Best Acc: {best_acc:.2%} -> Model saved to: {save_path}")

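        # Push the epoch's metrics to W&B in a single row.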
        wandb.log({
            "train_epoch_loss": train_epoch_loss,
            "train_epoch_acc": train_epoch_acc,
            "test_epoch_loss": val_epoch_loss,
            "test_epoch_acc": val_epoch_acc,
            "best_acc": best_acc,
            "bad_cases": bad_cases,
            "lr_back": lr_backbone,
            "lr_head": lr_head,
            "epoch": epoch,
        })

    wandb.finish()

if __name__ == '__main__':
    main()