# Uploaded by tusharsangam via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 9205b56 (verified).
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from dataset import Autoencoder_dataset
from model import Autoencoder
from torch.utils.tensorboard import SummaryWriter
import argparse
# Enables autograd anomaly detection globally for this process. PyTorch
# documents this mode as a debugging aid that slows down execution; consider
# turning it off for normal training runs.
torch.autograd.set_detect_anomaly(True)
def l2_loss(network_output, gt):
    """Return the mean squared difference between ``network_output`` and ``gt``.

    Both arguments are expected to be tensors of broadcast-compatible shape;
    the result is a 0-d tensor averaged over every element.
    """
    residual = network_output - gt
    return (residual * residual).mean()
def cos_loss(network_output, gt, dim=0):
    """Return ``1 - mean(cosine_similarity(network_output, gt, dim))``.

    Args:
        network_output: Reconstructed tensor.
        gt: Target tensor, same shape as ``network_output``.
        dim: Dimension along which cosine similarity is computed. Defaults to
            ``0`` to keep the behaviour this script has always trained with.

    Returns:
        A 0-d tensor; ``0`` when the inputs are parallel along ``dim``.

    NOTE(review): if the inputs are ``(batch, features)`` (presumably what the
    DataLoader yields here; confirm against ``Autoencoder_dataset``), then
    ``dim=0`` compares each feature column across the batch rather than each
    sample's feature vector. Per-sample similarity would be ``dim=-1``.
    """
    return 1 - F.cosine_similarity(network_output, gt, dim=dim).mean()
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--encoder_dims',
                        nargs='+',
                        type=int,
                        default=[256, 128, 64, 32, 3],
                        )
    parser.add_argument('--decoder_dims',
                        nargs='+',
                        type=int,
                        default=[16, 32, 64, 128, 256, 256, 512],
                        )
    parser.add_argument('--dataset_name', type=str, required=True)
    # Extra knobs; their defaults reproduce the previously hard-coded values.
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--eval_start_epoch', type=int, default=None,
                        help='first epoch (0-based) at which to evaluate and '
                             'track the best checkpoint; defaults to 96, or to '
                             'the last four epochs for shorter runs')
    return parser.parse_args()


def _evaluate(model, loader, device):
    """Return the per-sample mean evaluation loss of ``model`` over ``loader``.

    The per-batch loss is ``l2_loss + cos_loss`` (the cosine term is NOT
    down-weighted here, unlike in training). Each batch loss is weighted by
    its batch size so a short final batch does not skew the mean.
    Leaves ``model`` in eval mode.
    """
    model.eval()
    total = 0.0
    count = 0
    with torch.no_grad():
        for feature in loader:
            data = feature.to(device)
            outputs = model(data)
            loss = l2_loss(outputs, data) + cos_loss(outputs, data)
            # .item() keeps the accumulator a Python float instead of a
            # device tensor.
            total += loss.item() * len(feature)
            count += len(feature)
    return total / max(count, 1)


def main():
    """Train the autoencoder and write checkpoints/logs to ``ckpt/<dataset_name>``."""
    args = _parse_args()
    device = args.device
    num_epochs = args.num_epochs
    ckpt_dir = f'ckpt/{args.dataset_name}'
    os.makedirs(ckpt_dir, exist_ok=True)

    data_dir = f"{args.dataset_path}/language_features"
    train_dataset = Autoencoder_dataset(data_dir)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False,
    )
    # Evaluation re-reads the same dataset, unshuffled and in larger batches.
    test_loader = DataLoader(
        dataset=train_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
    )

    model = Autoencoder(args.encoder_dims, args.decoder_dims).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    eval_start_epoch = args.eval_start_epoch
    if eval_start_epoch is None:
        # A fixed `epoch > 95` gate is never reached when num_epochs <= 96,
        # which would leave no best checkpoint at all. Keep 96 for long runs
        # and fall back to the last four epochs for shorter ones.
        eval_start_epoch = min(96, max(num_epochs - 4, 0))

    best_eval_loss = float('inf')  # any finite eval loss improves on this
    best_epoch = 0
    iters_per_epoch = len(train_loader)

    # The context manager closes the writer (flushing pending events) even
    # if training raises.
    with SummaryWriter(ckpt_dir) as tb_writer:
        for epoch in tqdm(range(num_epochs)):
            model.train()
            for idx, feature in enumerate(train_loader):
                data = feature.to(device)
                outputs_dim3 = model.encode(data)
                outputs = model.decode(outputs_dim3)
                l2loss = l2_loss(outputs, data)
                cosloss = cos_loss(outputs, data)
                # The cosine term is heavily down-weighted during training.
                loss = l2loss + cosloss * 0.001
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                global_iter = epoch * iters_per_epoch + idx
                tb_writer.add_scalar('train_loss/l2_loss', l2loss.item(), global_iter)
                tb_writer.add_scalar('train_loss/cos_loss', cosloss.item(), global_iter)
                tb_writer.add_scalar('train_loss/total_loss', loss.item(), global_iter)
                tb_writer.add_histogram("feat", outputs, global_iter)

            if epoch >= eval_start_epoch:
                eval_loss = _evaluate(model, test_loader, device)
                print("eval_loss:{:.8f}".format(eval_loss))
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    best_epoch = epoch
                    torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/best_ckpt.pth')

            # Periodic snapshot, kept independent of the evaluation window:
            # nested under it, epochs 96-99 contain no multiple of 10, so no
            # snapshot would ever be written with the default settings.
            if epoch % 10 == 0:
                torch.save(model.state_dict(), f'ckpt/{args.dataset_name}/{epoch}_ckpt.pth')

    print(f"best_epoch: {best_epoch}")
    print("best_loss: {:.8f}".format(best_eval_loss))


if __name__ == '__main__':
    main()