|
|
import os |
|
|
import random |
|
|
import json |
|
|
import torch |
|
|
import pprint |
|
|
import collections |
|
|
import numpy as np |
|
|
from torch import nn |
|
|
from tensorboardX import SummaryWriter |
|
|
from tqdm import trange |
|
|
|
|
|
class Module(nn.Module):
    '''
    Base Seq2Seq agent with common train and val loops.

    Subclasses implement featurize, forward, extract_preds, compute_loss and
    compute_metric; this class supplies the epoch/batch training loop, the
    validation loop, checkpointing, and a few input-ablation helpers.
    '''

    def __init__(self, args, vocab):
        '''
        Set up shared embeddings, special-token ids and the python RNG seed.

        :param args: namespace-like experiment config (``vars(args)`` is
            JSON-dumped by run_train); ``demb`` and ``seed`` are read here,
            other fields by the train/eval loops.
        :param vocab: dict with ``'word'`` and ``'action_low'`` vocabularies;
            both must support ``len()`` and ``'action_low'`` must provide
            ``word2index(token, train=False)``.
        '''
        super().__init__()

        # token ids: pad is used by zero_input / zero_input_list;
        # seg is not used in this class (presumably consumed by subclasses)
        self.pad = 0
        self.seg = 1

        # args and vocab
        self.args = args
        self.vocab = vocab

        # embeddings for language tokens and low-level actions
        self.emb_word = nn.Embedding(len(vocab['word']), args.demb)
        self.emb_action_low = nn.Embedding(len(vocab['action_low']), args.demb)

        # ids of the special end-of-sequence / segment action tokens
        self.stop_token = self.vocab['action_low'].word2index("<<stop>>", train=False)
        self.seg_token = self.vocab['action_low'].word2index("<<seg>>", train=False)

        # seed the global python RNG (drives the per-epoch shuffle in run_train)
        random.seed(a=args.seed)

        # tensorboard writer; created by run_train
        self.summary_writer = None

    def _log_scalar(self, tag, value, step):
        '''
        Write a scalar to tensorboard, or do nothing if no writer exists yet
        (e.g. when run_pred is called without a preceding run_train).
        '''
        if self.summary_writer is not None:
            self.summary_writer.add_scalar(tag, value, step)

    def _save_checkpoint(self, fsave, stats, optimizer):
        '''
        Serialize metrics, model/optimizer state, args and vocab to ``fsave``
        in the format expected by ``load``.
        '''
        torch.save({
            'metric': stats,
            'model': self.state_dict(),
            'optim': optimizer.state_dict(),
            'args': self.args,
            'vocab': self.vocab,
        }, fsave)

    def _save_best(self, split, tag, stats, optimizer, preds, data, args):
        '''
        Persist a new best model for ``split``: ``best_<tag>.pth``,
        ``best_<tag>.json`` (stats) and ``<split>.debug.preds.json``.
        '''
        self._save_checkpoint(os.path.join(args.dout, 'best_%s.pth' % tag), stats, optimizer)
        with open(os.path.join(args.dout, 'best_%s.json' % tag), 'wt') as f:
            json.dump(stats, f, indent=2)
        with open(os.path.join(args.dout, '%s.debug.preds.json' % split), 'wt') as f:
            json.dump(self.make_debug(preds, data), f, indent=2)

    def _train_epoch(self, train, optimizer, args, train_iter):
        '''
        Run one optimisation pass over ``train``; returns the advanced step
        counter (incremented by the number of examples in each batch).
        '''
        for batch, feat in self.iterate(train, args.batch):
            out = self.forward(feat)
            loss = self.compute_loss(out, batch, feat)
            for k, v in loss.items():
                self._log_scalar('train/loss_' + k, v.item(), train_iter)

            # optimizer backward pass on the sum of all loss terms
            optimizer.zero_grad()
            sum_loss = sum(loss.values())
            sum_loss.backward()
            optimizer.step()

            self._log_scalar('train/loss', sum_loss.item(), train_iter)
            # len(batch) (not args.batch) keeps the counter exact on the last
            # partial batch and matches how run_pred advances its counter
            train_iter += len(batch)
        return train_iter

    def _evaluate_split(self, data, args, name, start_iter):
        '''
        Run run_pred on ``data``, add compute_metric results and the mean loss
        under 'total_loss', and log that mean loss.

        :returns: (preds, next_iter, mean_loss, metrics)
        '''
        preds, next_iter, total_loss, metrics = self.run_pred(data, args=args, name=name, iter=start_iter)
        metrics.update(self.compute_metric(preds, data))
        metrics['total_loss'] = float(total_loss)
        self._log_scalar('%s/total_loss' % name, metrics['total_loss'], next_iter)
        return preds, next_iter, total_loss, metrics

    def run_train(self, splits, args=None, optimizer=None):
        '''
        training loop

        Trains for ``args.epoch`` epochs on ``splits['train']`` and evaluates
        on ``splits['valid_seen']`` / ``splits['valid_unseen']`` after each
        epoch. Writes into ``args.dout``: ``config.json``, tensorboard logs,
        ``best_seen.*`` / ``best_unseen.*`` (checkpoint, stats json and
        ``<split>.debug.preds.json``) whenever that split's mean validation
        loss improves, and every epoch either ``net_epoch_<n>.pth`` (with
        ``args.save_every_epoch``) or ``latest.pth``.

        :param splits: dict with 'train', 'valid_seen' and 'valid_unseen'
            sequences of task descriptors (as accepted by load_task_json).
        :param args: config to use; defaults to ``self.args``.
        :param optimizer: optimizer to use; defaults to Adam at ``args.lr``.
        '''
        args = args or self.args

        # splits
        train = splits['train']
        valid_seen = splits['valid_seen']
        valid_unseen = splits['valid_unseen']

        # debugging: restrict to a small subset of the data.
        # NOTE: despite its name, dataset_fraction is used as an absolute
        # example count: 70% of it for train, 15% for each validation split.
        if args.dataset_fraction > 0:
            small_train_size = int(args.dataset_fraction * 0.7)
            small_valid_size = int((args.dataset_fraction * 0.3) / 2)
            train = train[:small_train_size]
            valid_seen = valid_seen[:small_valid_size]
            valid_unseen = valid_unseen[:small_valid_size]

        # debugging: use a very small subset of the data
        if args.fast_epoch:
            train = train[:16]
            valid_seen = valid_seen[:16]
            valid_unseen = valid_unseen[:16]

        # shuffle a private copy so the caller's split is not reordered in place
        train = list(train)

        # tensorboard logger and a record of the config used
        self.summary_writer = SummaryWriter(log_dir=args.dout)
        fconfig = os.path.join(args.dout, 'config.json')
        with open(fconfig, 'wt') as f:
            json.dump(vars(args), f, indent=2)

        optimizer = optimizer or torch.optim.Adam(self.parameters(), lr=args.lr)

        print("Saving to: %s" % args.dout)
        best_loss = {'valid_seen': 1e10, 'valid_unseen': 1e10}
        train_iter, valid_seen_iter, valid_unseen_iter = 0, 0, 0
        for epoch in trange(0, args.epoch, desc='epoch'):
            self.train()
            self.adjust_lr(optimizer, args.lr, epoch, decay_epoch=args.decay_epoch)
            random.shuffle(train)
            train_iter = self._train_epoch(train, optimizer, args, train_iter)

            # validation (run_pred switches the model to eval mode)
            p_valid_seen, valid_seen_iter, total_valid_seen_loss, m_valid_seen = self._evaluate_split(
                valid_seen, args, 'valid_seen', valid_seen_iter)
            p_valid_unseen, valid_unseen_iter, total_valid_unseen_loss, m_valid_unseen = self._evaluate_split(
                valid_unseen, args, 'valid_unseen', valid_unseen_iter)

            stats = {'epoch': epoch,
                     'valid_seen': m_valid_seen,
                     'valid_unseen': m_valid_unseen}

            # new best valid_seen loss
            if total_valid_seen_loss < best_loss['valid_seen']:
                print('\nFound new best valid_seen!! Saving...')
                self._save_best('valid_seen', 'seen', stats, optimizer, p_valid_seen, valid_seen, args)
                best_loss['valid_seen'] = total_valid_seen_loss

            # new best valid_unseen loss
            if total_valid_unseen_loss < best_loss['valid_unseen']:
                print('Found new best valid_unseen!! Saving...')
                self._save_best('valid_unseen', 'unseen', stats, optimizer, p_valid_unseen, valid_unseen, args)
                best_loss['valid_unseen'] = total_valid_unseen_loss

            # save the latest checkpoint
            if args.save_every_epoch:
                fsave = os.path.join(args.dout, 'net_epoch_%d.pth' % epoch)
            else:
                fsave = os.path.join(args.dout, 'latest.pth')
            self._save_checkpoint(fsave, stats, optimizer)

            # write per-split stats to tensorboard, keyed by the train step
            for split, metrics in stats.items():
                if isinstance(metrics, dict):
                    for k, v in metrics.items():
                        self._log_scalar(split + '/' + k, v, train_iter)
            pprint.pprint(stats)

    def run_pred(self, dev, args=None, name='dev', iter=0):
        '''
        validation loop

        Runs the model in eval mode without gradient tracking over ``dev``,
        collecting predictions and per-term losses.

        :param dev: sequence of task descriptors.
        :param args: config to use; defaults to ``self.args``.
        :param name: tensorboard tag prefix.
        :param iter: starting value of the logging step counter.
        :returns: (preds dict merged from extract_preds, step counter advanced
            by the number of examples seen, mean per-batch total loss, dict of
            mean per-term losses keyed 'loss_<term>'). The mean loss is
            float('inf') when ``dev`` yields no batches, so an empty split
            never counts as an improvement.
        '''
        args = args or self.args
        m_dev = collections.defaultdict(list)
        p_dev = {}
        batch_losses = []
        dev_iter = iter
        self.eval()
        # evaluation needs no gradients; skip building the autograd graph
        with torch.no_grad():
            for batch, feat in self.iterate(dev, args.batch):
                out = self.forward(feat)
                p_dev.update(self.extract_preds(out, batch, feat))
                loss = self.compute_loss(out, batch, feat)
                for k, v in loss.items():
                    ln = 'loss_' + k
                    m_dev[ln].append(v.item())
                    self._log_scalar("%s/%s" % (name, ln), v.item(), dev_iter)
                sum_loss = float(sum(loss.values()))
                self._log_scalar("%s/loss" % (name), sum_loss, dev_iter)
                batch_losses.append(sum_loss)
                dev_iter += len(batch)

        m_dev = {k: sum(v) / len(v) for k, v in m_dev.items()}
        total_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('inf')
        return p_dev, dev_iter, total_loss, m_dev

    def featurize(self, batch):
        '''Convert a list of loaded task dicts into model inputs (subclass hook).'''
        raise NotImplementedError()

    def forward(self, feat, max_decode=100):
        '''Run the model on featurized input (subclass hook).'''
        raise NotImplementedError()

    def extract_preds(self, out, batch, feat):
        '''Turn model output into a dict of predictions keyed by task/ann id (subclass hook).'''
        raise NotImplementedError()

    def compute_loss(self, out, batch, feat):
        '''Return a dict of named scalar loss tensors (subclass hook).'''
        raise NotImplementedError()

    def compute_metric(self, preds, data):
        '''Return a dict of evaluation metrics for ``preds`` on ``data`` (subclass hook).'''
        raise NotImplementedError()

    def get_task_and_ann_id(self, ex):
        '''
        single string for task_id and annotation repeat idx, "<task_id>_<repeat_idx>"
        '''
        return "%s_%s" % (ex['task_id'], str(ex['ann']['repeat_idx']))

    def make_debug(self, preds, data):
        '''
        readable output generator for debugging

        Maps each task/ann id to its goal description, ground-truth low-level
        action names, and the predicted 'action_low' string split on whitespace.
        '''
        debug = {}
        for task in data:
            ex = self.load_task_json(task)
            i = self.get_task_and_ann_id(ex)
            debug[i] = {
                'lang_goal': ex['turk_annotations']['anns'][ex['ann']['repeat_idx']]['task_desc'],
                'action_low': [a['discrete_action']['action'] for a in ex['plan']['low_actions']],
                'p_action_low': preds[i]['action_low'].split(),
            }
        return debug

    def load_task_json(self, task):
        '''
        load preprocessed json from disk:
        <args.data>/<task['task']>/<args.pp_folder>/ann_<task['repeat_idx']>.json
        '''
        json_path = os.path.join(self.args.data, task['task'], '%s' % self.args.pp_folder, 'ann_%d.json' % task['repeat_idx'])
        with open(json_path) as f:
            data = json.load(f)
        return data

    def get_task_root(self, ex):
        '''
        returns the folder path of a trajectory:
        <args.data>/<ex['split']>/<last two '/'-separated parts of ex['root']>
        '''
        return os.path.join(self.args.data, ex['split'], *(ex['root'].split('/')[-2:]))

    def iterate(self, data, batch_size):
        '''
        breaks dataset into batch_size chunks; yields (loaded task dicts, featurize output)
        '''
        for i in trange(0, len(data), batch_size, desc='batch'):
            tasks = data[i:i+batch_size]
            batch = [self.load_task_json(task) for task in tasks]
            feat = self.featurize(batch)
            yield batch, feat

    def zero_input(self, x, keep_end_token=True):
        '''
        pad input with zeros (used for ablations)

        Replaces every element but the last with ``self.pad``; the last is kept
        when ``keep_end_token`` is true, otherwise also replaced.
        '''
        end_token = [x[-1]] if keep_end_token else [self.pad]
        return list(np.full_like(x[:-1], self.pad)) + end_token

    def zero_input_list(self, x, keep_end_token=True):
        '''
        pad a list of input with zeros (used for ablations)

        Like zero_input, but each non-final element is itself a sequence that
        is replaced by a same-shaped run of ``self.pad``.
        '''
        end_token = [x[-1]] if keep_end_token else [self.pad]
        lz = [list(np.full_like(i, self.pad)) for i in x[:-1]] + end_token
        return lz

    @staticmethod
    def adjust_lr(optimizer, init_lr, epoch, decay_epoch=5):
        '''
        decay learning rate by 10x every decay_epoch epochs
        '''
        lr = init_lr * (0.1 ** (epoch // decay_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    @classmethod
    def load(cls, fsave):
        '''
        load pth model from disk

        :returns: (model rebuilt from the saved args/vocab, Adam optimizer with
            the saved optimizer state restored)
        '''
        # NOTE(review): torch.load unpickles the saved args/vocab objects, so
        # only load trusted checkpoints; on torch versions where weights_only
        # defaults to True this call may need weights_only=False.
        save = torch.load(fsave)
        model = cls(save['args'], save['vocab'])
        model.load_state_dict(save['model'])
        # lr here is a placeholder; load_state_dict restores the saved param groups
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        optimizer.load_state_dict(save['optim'])
        return model, optimizer

    @classmethod
    def has_interaction(cls, action):
        '''
        check if low-level action is interactive: False when the action name
        contains any navigation/special token substring, True otherwise
        '''
        non_interact_actions = ['MoveAhead', 'Rotate', 'Look', '<<stop>>', '<<pad>>', '<<seg>>']
        return not any(a in action for a in non_interact_actions)
|
|
|