| import torch |
| import torch.nn.functional as F |
| from torch.optim import Adam |
| from torch.utils.data import DataLoader |
| from Diffusion.networks import get_net |
| from Dataloader.dataLoader import * |
| import argparse |
| import yaml |
| import os |
| import time |
| import swanlab |
|
|
# Command-line interface: a single option (--config / -C) selecting the YAML
# hyper-parameter file; falls back to the contrastive OM config.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_om_contrastive.yaml",
)
args = parser.parse_args()
|
|
# Load the hyper-parameter dictionary from the YAML file given by --config.
# An explicit encoding keeps parsing independent of the platform locale.
with open(args.config, 'r', encoding='utf-8') as file:
    hyp = yaml.safe_load(file)
# yaml.safe_load yields None for an empty file (or a list/scalar for a
# non-mapping document); fail here with a clear message rather than with an
# opaque TypeError at the first hyp[...] lookup further down.
if not isinstance(hyp, dict):
    raise ValueError(
        f"Config file {args.config!r} must contain a YAML mapping, "
        f"got {type(hyp).__name__}"
    )
|
|
| |
# Runtime settings read from the config. The configured device is used only
# when CUDA is available; otherwise everything falls back to the CPU.
if torch.cuda.is_available():
    device = torch.device(hyp['device'])
else:
    device = torch.device('cpu')
data_name, net_name, ndims, img_size = (
    hyp[key] for key in ('data_name', 'net_name', 'ndims', 'img_size')
)
# Output directory Models/<data_name>_<net_name>/, created if it is missing.
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
os.makedirs(model_save_path, exist_ok=True)
|
|
| |
# Open a SwanLab run in the "OM" project, attaching the full config dict.
swanlab.init(
    config=hyp,
    project="OM",
)
|
|
| |
# Look up the network class by name, instantiate it on the target device,
# and create an Adam optimizer over all of its parameters.
Net = get_net(net_name)
net_kwargs = dict(
    n_steps=hyp['timesteps'],
    ndims=ndims,
    num_input_chn=hyp['num_input_chn'],
    res=img_size,
)
model = Net(**net_kwargs).to(device)
optimizer = Adam(params=model.parameters(), lr=hyp['lr'])
|
|
| |
# Training data: OMDataset_indiv at the configured output size with no extra
# transform; batches are reshuffled every epoch and a trailing partial batch
# is dropped.
dataset = OMDataset_indiv(transform=None, out_sz=img_size)
train_loader = DataLoader(
    dataset,
    batch_size=hyp['batchsize'],
    drop_last=True,
    shuffle=True,
)
|
|
| |
# Train for hyp['epoch'] epochs, minimising 1 - cosine similarity between the
# model's image embedding and the embedding provided by the dataset.
print('start training...')
model.train()  # be explicit about training mode rather than relying on the default
for epoch in range(hyp['epoch']):
    epoch_loss = 0.0

    for volume, embd in train_loader:
        t0 = time.time()
        volume = volume.float().to(device)
        # Cast like `volume` so the loss is not silently promoted to the
        # embedding's dtype (e.g. float64 if the dataset builds it from NumPy).
        embd = embd.float().to(device)
        # One random timestep per sample, drawn directly on the target device.
        t = torch.randint(0, hyp['timesteps'], (volume.shape[0],), device=device)

        # The model returns a pair; only the second element is used here.
        _, img_embd = model(x=volume, y=volume, t=t)

        # Cosine similarity along the last dim, averaged over the batch;
        # the loss is 0 when every pair of embeddings is perfectly aligned.
        loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # one host sync per batch instead of two
        epoch_loss += loss_value
        dt = time.time() - t0
        # Log both per-batch metrics in a single call so they share a step.
        swanlab.log({"loss": loss_value, "Time(mins)/batch": dt / 60})

    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch:04d} | Loss: {avg_loss:.6f}")
    swanlab.log({"Avg Loss/epoch": avg_loss})

    # model_save_path is created at startup but was never written to; persist
    # the weights after every epoch so a finished or interrupted run is kept.
    torch.save(model.state_dict(), os.path.join(model_save_path, 'latest.pth'))
| |
| |
| |
| |
| |
|
|