import argparse
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import Autoencoder_dataset
from model import Autoencoder

# Anomaly detection helps debug autograd errors but slows training noticeably.
torch.autograd.set_detect_anomaly(True)
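# Example invocation (the file name train.py and the paths are illustrative;
# substitute your own scene directory and name):
#   python train.py --dataset_path data/my_scene --dataset_name my_scene \
#       --num_epochs 100 --lr 0.0001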


def l2_loss(network_output, gt):
    # Mean squared reconstruction error.
    return ((network_output - gt) ** 2).mean()


def cos_loss(network_output, gt):
    # Cosine-similarity loss. Note that dim=0 compares along the batch
    # dimension; per-sample feature similarity would use dim=-1 instead.
    return 1 - F.cosine_similarity(network_output, gt, dim=0).mean()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--encoder_dims',
                        nargs='+',
                        type=int,
                        default=[256, 128, 64, 32, 3],
                        )
    parser.add_argument('--decoder_dims',
                        nargs='+',
                        type=int,
                        default=[16, 32, 64, 128, 256, 256, 512],
                        )
    parser.add_argument('--dataset_name', type=str, required=True)
    args = parser.parse_args()

    dataset_path = args.dataset_path
    num_epochs = args.num_epochs
    data_dir = f"{dataset_path}/language_features"
    os.makedirs(f'ckpt/{args.dataset_name}', exist_ok=True)

    train_dataset = Autoencoder_dataset(data_dir)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=16,
        drop_last=False,
    )
    # NOTE: evaluation below reuses the training set; there is no held-out split.
    test_loader = DataLoader(
        dataset=train_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=16,
        drop_last=False,
    )

    encoder_hidden_dims = args.encoder_dims
    decoder_hidden_dims = args.decoder_dims
    model = Autoencoder(encoder_hidden_dims, decoder_hidden_dims).to("cuda:0")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    logdir = f'ckpt/{args.dataset_name}'
    tb_writer = SummaryWriter(logdir)

    best_eval_loss = float("inf")
    best_epoch = 0
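    # The curves logged below can be inspected while training is running with:
    #   tensorboard --logdir ckpt/<dataset_name>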

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for idx, feature in enumerate(train_loader):
            data = feature.to("cuda:0")

            # Compress to the low-dimensional latent, then reconstruct.
            outputs_dim3 = model.encode(data)
            outputs = model.decode(outputs_dim3)

            l2loss = l2_loss(outputs, data)
            cosloss = cos_loss(outputs, data)
            # L2 reconstruction dominates; the cosine term is down-weighted.
            loss = l2loss + cosloss * 0.001

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_iter = epoch * len(train_loader) + idx
            tb_writer.add_scalar('train_loss/l2_loss', l2loss.item(), global_iter)
            tb_writer.add_scalar('train_loss/cos_loss', cosloss.item(), global_iter)
            tb_writer.add_scalar('train_loss/total_loss', loss.item(), global_iter)
            tb_writer.add_histogram("feat", outputs, global_iter)

        # Run evaluation only over the final few epochs.
        if epoch > num_epochs - 5:
            eval_loss = 0.0
            model.eval()
            for idx, feature in enumerate(test_loader):
                data = feature.to("cuda:0")
                with torch.no_grad():
                    outputs = model(data)
                    loss = l2_loss(outputs, data) + cos_loss(outputs, data)
                # Weight each batch by its size so the average is exact.
                eval_loss += loss.item() * len(feature)
            eval_loss = eval_loss / len(train_dataset)
            print(f"eval_loss: {eval_loss:.8f}")
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                best_epoch = epoch
                torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/best_ckpt.pth')

        # Periodic checkpoint every 10 epochs.
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/{epoch}_ckpt.pth')

    tb_writer.close()
    print(f"best_epoch: {best_epoch}")
    print(f"best_loss: {best_eval_loss:.8f}")