import os
import sys
from pathlib import Path

import torch
import wandb
import yaml
from torch import nn
from torch.amp import GradScaler

# Allow TF32 matmul kernels on Ampere+ GPUs: faster training at slightly reduced precision.
torch.set_float32_matmul_precision('high')

# Pin the host C/C++ toolchain used when torch.compile builds generated code.
os.environ["CXX"] = "/usr/bin/g++"
os.environ["CC"] = "/usr/bin/gcc"

# Make the project root importable so the `src` package resolves regardless of CWD.
# (sys.path holds strings, so compare against str(ROOT_DIR), not the Path itself.)
ROOT_DIR = Path(__file__).resolve().parent
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

from src.dataset import get_dataloader
from src.engine import evaluate, train_one_epoch
from src.model import ResNet18_CIFAR, SimpleCNN, TransferResNet50
from src.utils import get_device, seed_everthing


def load_yaml(config_path=None):
    if config_path is None:
        config_path = ROOT_DIR / 'config.yaml'
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Config file not found: {config_path}")
        sys.exit(1)
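

# A config.yaml sketch inferred from the keys this script reads below. The
# section and key names are the ones actually accessed; every value shown is
# illustrative, not taken from the real project config:
#
#   wandb_setup:
#     project: my_project
#     experiment: baseline
#     tags: [cifar]
#     job_type: train
#   data:
#     data_path: data/
#     batch_size: 128
#   model:
#     type: ResNet18            # SimpleCNN | ResNet18 | TransferResNet50
#     num_classes: 10
#     dropout_rate: 0.1
#   optimizer:
#     name: adamw               # adam | adamw | sgd
#     lr: 1e-3
#     weight_decay: 1e-2
#   scheduler:
#     use_scheduler: true
#     type: CosineAnnealingLR   # or StepLR
#     eta_min: 0.0
#   train:
#     epochs: 50
#     save_dir: checkpoints/
#   seed: 42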


def main():
    static_config = load_yaml()
    wandb_cfg = static_config['wandb_setup']

    wandb.init(
        project=wandb_cfg.get('project', 'my_project'),
        group=wandb_cfg.get('experiment', 'default'),
        tags=wandb_cfg.get('tags', []),
        job_type=wandb_cfg.get('job_type', 'train'),
        config=static_config,
    )

    # wandb.config merges the YAML config with any sweep-provided overrides,
    # so hyperparameters read through it are correct in manual and sweep runs.
    cfg = wandb.config

    relative_save_dir = cfg['train']['save_dir']
    save_dir = (ROOT_DIR / relative_save_dir).resolve()
    os.makedirs(save_dir, exist_ok=True)

    best_acc = 0.0

    print(f"Experiment Start! Mode: {'Sweep' if wandb.run.sweep_id else 'Manual'}")
    print(f"Save dir: {save_dir}")
    print(f"Model: {cfg['model']['type']}")
    print(f"LR: {cfg['optimizer']['lr']}, Batch: {cfg['data']['batch_size']}, Opt: {cfg['optimizer']['name']}")

    seed_everthing(cfg.get('seed', 42))
    device = get_device()

    # Resolve the dataset path against the project root and hand the dataloader
    # a plain dict copy so the shared wandb config is never mutated.
    relative_data_path = cfg['data']['data_path']
    absolute_data_path = (ROOT_DIR / relative_data_path).resolve()

    data_cfg = cfg['data'].copy()
    data_cfg['data_path'] = str(absolute_data_path)

    print(f'Loading data from {absolute_data_path}...')
    train_loader, test_loader = get_dataloader(data_cfg)

    # Sanity-check one batch before building the model.
    dummy_x, dummy_y = next(iter(train_loader))
    print(f"🧐 Inspection - Input Shape: {dummy_x.shape}")

    model_type = cfg['model']['type']
    num_classes = cfg['model']['num_classes']
    dropout_rate = cfg['model'].get('dropout_rate', 0.0)
    num_inputs = cfg['model'].get('num_inputs', 3)
    input_size = cfg['model'].get('input_size', 32)

    if model_type == 'SimpleCNN':
        model = SimpleCNN(
            num_inputs=num_inputs,
            input_size=input_size,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    elif model_type == 'ResNet18':
        model = ResNet18_CIFAR(
            num_inputs=num_inputs,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    elif model_type == 'TransferResNet50':
        model = TransferResNet50(
            num_classes=num_classes,
            dropout_rate=dropout_rate,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # channels_last (NHWC) memory format tends to be faster for convolutions
    # under mixed precision on recent GPUs.
    model = model.to(device, memory_format=torch.channels_last)

    # 'reduce-overhead' mode can use CUDA graphs to cut per-step launch overhead.
    # Models exposing a `.net` backbone get only the backbone compiled; everything
    # else is compiled whole.
    if hasattr(model, 'net'):
        print("⚡ Compiling ResNet backbone...")
        model.net = torch.compile(model.net, mode='reduce-overhead')
    else:
        print("⚡ Compiling Full Model...")
        model = torch.compile(model, mode='reduce-overhead')
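
    # Note: torch.compile wraps modules in an OptimizedModule, so parameter
    # names gain an `_orig_mod.` prefix (e.g. `net._orig_mod.fc.weight`). The
    # substring match on "fc" below still works; exact-name matching would not.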

    opt_cfg = cfg['optimizer']
    opt_name = opt_cfg['name'].lower()

    # Discriminative learning rates: the classifier head trains at the full LR,
    # the backbone at a smaller one (default: 10x lower).
    lr_head = float(opt_cfg['lr'])
    lr_backbone = float(opt_cfg.get('backbone_lr', lr_head * 0.1))
    weight_decay = float(opt_cfg.get('weight_decay', 0.0))

    # Split parameters by name: anything whose name contains "fc" is treated as
    # the head (torchvision-style ResNets name their classifier `fc`). Models
    # without an fc-named layer end up entirely in the backbone group.
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if "fc" in name:
            head_params.append(param)
        else:
            backbone_params.append(param)

    print(f"🔧 Optimizer Setup: Head LR={lr_head}, Backbone LR={lr_backbone}")

    # All three optimizers share the same two param groups; only the optimizer
    # class (and SGD's momentum) differs.
    param_groups = [
        {'params': backbone_params, 'lr': lr_backbone},
        {'params': head_params, 'lr': lr_head},
    ]

    if opt_name == "adam":
        optimizer = torch.optim.Adam(param_groups, weight_decay=weight_decay)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
    elif opt_name == "sgd":
        optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {opt_name}")

    scheduler = None
    if 'scheduler' in cfg and cfg['scheduler'].get('use_scheduler', False):
        sch_cfg = cfg['scheduler']

        if sch_cfg['type'] == 'CosineAnnealingLR':
            # Cosine annealing: lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
            t_max = cfg['train']['epochs']
            eta_min = float(sch_cfg.get('eta_min', 0.0))
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=t_max,
                eta_min=eta_min,
            )
        elif sch_cfg['type'] == 'StepLR':
            step_size = sch_cfg.get('step_size', 10)
            gamma = sch_cfg.get('gamma', 0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=step_size,
                gamma=gamma,
            )
        else:
            print(f"Unknown scheduler type '{sch_cfg['type']}'; not using a learning rate scheduler")

    # Label smoothing softens the one-hot targets as a mild regularizer.
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    epochs = cfg['train']['epochs']
    # GradScaler handles loss scaling for mixed-precision (AMP) training;
    # train_one_epoch is expected to run its forward pass under autocast.
    scaler = GradScaler('cuda')

    for epoch in range(epochs):
        train_epoch_loss, train_epoch_acc = train_one_epoch(
            epoch, model, train_loader, loss_fn, optimizer, device, scaler
        )
        val_epoch_loss, val_epoch_acc, bad_cases = evaluate(
            epoch, model, test_loader, loss_fn, device
        )

        # Read the LR actually used this epoch (param group 0 is the backbone)
        # before the scheduler advances it.
        current_lr = optimizer.param_groups[0]['lr']
        if scheduler is not None:
            scheduler.step()

        print(
            f"Epoch {epoch+1}/{epochs}\t[LR: {current_lr:>.6f}]\t"
            f"Train Loss: {train_epoch_loss:>.3f}\tTrain Acc: {train_epoch_acc:>.2%}\t|\t"
            f"Val Loss: {val_epoch_loss:>.3f}\tVal Acc: {val_epoch_acc:>.2%}"
        )

        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            save_name = f"{cfg['wandb_setup']['experiment']}_best.pth"
            save_path = save_dir / save_name
            torch.save(model.state_dict(), save_path)
            print(f"🌟 New Best Acc: {best_acc:.2%} -> Model saved to: {save_path}")
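
            # Because (part of) the model is compiled, the saved state_dict
            # keys carry the `_orig_mod.` prefix noted above. A sketch for
            # loading into an uncompiled model (`plain_model` is a hypothetical
            # stand-in):
            #   state = torch.load(save_path, map_location='cpu')
            #   state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
            #   plain_model.load_state_dict(state)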

        wandb.log({
            "train_epoch_loss": train_epoch_loss,
            "train_epoch_acc": train_epoch_acc,
            "test_epoch_loss": val_epoch_loss,
            "test_epoch_acc": val_epoch_acc,
            "best_acc": best_acc,
            "bad_cases": bad_cases,
            "learning_rate": current_lr,
            "epoch": epoch,
        })

    wandb.finish()


if __name__ == '__main__':
    main()
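
# Typical usage (the script and sweep filenames here are placeholders):
#   python main.py                  # single run configured by config.yaml
#   wandb sweep sweep.yaml          # register a hyperparameter sweep, then
#   wandb agent <entity>/<project>/<sweep_id>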