"""Train a network so that its image embedding matches a target embedding.

The script reads hyper-parameters from a YAML config (``--config`` / ``-C``),
builds the network class returned by ``get_net(net_name)``, and minimises
``1 - mean(cosine_similarity(img_embd, embd))`` where ``img_embd`` is the
second value returned by the model and ``embd`` is the embedding yielded by
the dataset. Per-step loss, per-step wall time and per-epoch average loss are
logged to SwanLab.

Config keys read here: ``device``, ``data_name``, ``net_name``, ``ndims``,
``img_size``, ``timesteps``, ``num_input_chn``, ``lr``, ``batchsize``,
``epoch``.
"""
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from Diffusion.networks import get_net
# NOTE: a wildcard import can rebind any name imported above it; the stdlib
# imports are kept *after* it (original order) so they cannot be shadowed.
# OMDataset_indiv is the only name used below that is not imported explicitly,
# so it presumably comes from here.
from Dataloader.dataLoader import *
import argparse
import yaml
import os
import time
import swanlab

parser = argparse.ArgumentParser()
parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
args = parser.parse_args()

with open(args.config, 'r') as file:
    hyp = yaml.safe_load(file)

# Setup
# Fall back to CPU when CUDA is unavailable, whatever the config requests.
device = torch.device(hyp['device'] if torch.cuda.is_available() else 'cpu')
data_name = hyp['data_name']
net_name = hyp['net_name']
ndims = hyp['ndims']
img_size = hyp['img_size']
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
os.makedirs(model_save_path, exist_ok=True)

# SwanLab
swanlab.init(project="OM", config=hyp)

# Model
Net = get_net(net_name)
model = Net(
    n_steps=hyp['timesteps'],
    ndims=ndims,
    num_input_chn=hyp['num_input_chn'],
    res=img_size,
).to(device)
optimizer = Adam(model.parameters(), lr=hyp['lr'])

# Data
dataset = OMDataset_indiv(out_sz=img_size, transform=None)
train_loader = DataLoader(dataset, batch_size=hyp['batchsize'], shuffle=True, drop_last=True)

# Training
print('start training...')
for epoch in range(hyp['epoch']):
    epoch_loss = 0.0
    for volume, embd in train_loader:
        # Timing starts after the batch is fetched, so it covers compute and
        # logging for the step but not data loading.
        t0 = time.time()
        volume = volume.float().to(device)
        embd = embd.to(device)  # [B, 1024] GT text embedding

        # Sample one timestep per sample directly on the target device
        # (avoids a host allocation plus a host-to-device copy every step).
        t = torch.randint(0, hyp['timesteps'], (volume.shape[0],), device=device)

        _, img_embd = model(x=volume, y=volume, t=t)  # img_embd: [B, 1024]

        # Cosine similarity loss: align img_embd with GT text embedding
        loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        swanlab.log({"loss": loss_value})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss_value
        t1 = time.time()
        dt = t1 - t0
        swanlab.log({"Time(mins)/batch": dt/60})

    # max(..., 1) guards against an empty loader (e.g. drop_last removing the
    # only, partial batch), which would otherwise divide by zero.
    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch:04d} | Loss: {avg_loss:.6f}")
    swanlab.log({"Avg Loss/epoch": avg_loss})

    # NOTE(review): checkpointing is currently disabled, so no weights are
    # written to model_save_path by this script.
    # if epoch % hyp['epoch_per_save'] == 0:
    #     save_path = model_save_path + str(epoch).rjust(6, '0') + f'_{data_name}_{net_name}.pth'
    #     torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, save_path)
    #     print(f"Saved: {save_path}")