| import os |
| import numpy as np |
|
|
| import torch |
| import torch.nn.functional as F |
| from pytorch_lightning import LightningModule |
|
|
| from cliport.tasks import cameras |
| from cliport.utils import utils |
| from cliport.models.core.attention import Attention |
| from cliport.models.core.transport import Transport |
| from cliport.models.streams.two_stream_attention import TwoStreamAttention |
| from cliport.models.streams.two_stream_transport import TwoStreamTransport |
|
|
| from cliport.models.streams.two_stream_attention import TwoStreamAttentionLat |
| from cliport.models.streams.two_stream_transport import TwoStreamTransportLat |
| import time |
| import IPython |
|
|
class TransporterAgent(LightningModule):
    """Base pick-and-place Transporter agent trained with manual optimization.

    Subclasses implement ``_build_model`` and assign ``self.attention`` (the
    pick model, used for ``pick_conf`` in ``act``) and ``self.transport`` (the
    place model, used for ``place_conf`` in ``act``). Each model gets its own
    Adam optimizer, kept in ``self._optimizers`` and stepped explicitly inside
    ``attn_criterion`` / ``transport_criterion``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        """Store config and loaders, build the models and their optimizers.

        Args:
            name: Agent name, printed in the start-up log line.
            cfg: Config mapping; reads ``cfg['train']`` keys ``task``,
                ``n_rotations``, ``val_repeats``, ``save_steps``, ``lr`` and ``log``.
            train_ds: Training loader, returned by ``train_dataloader``; its
                ``.dataset`` attribute is kept as ``self.train_ds``.
            test_ds: Validation loader, returned by ``val_dataloader``; its
                ``.dataset`` attribute is kept as ``self.test_ds`` (used by ``act``).
        """
        super().__init__()
        utils.set_seed(0)
        # Optimizers are stepped by hand in attn_criterion / transport_criterion.
        self.automatic_optimization = False
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.name = name
        self.cfg = cfg
        self.train_loader = train_ds
        self.test_loader = test_ds

        self.train_ds = train_ds.dataset
        self.test_ds = test_ds.dataset

        self.task = cfg['train']['task']
        self.total_steps = 0
        self.crop_size = 64
        self.n_rotations = cfg['train']['n_rotations']

        self.pix_size = 0.003125
        self.in_shape = (320, 160, 6)
        self.cam_config = cameras.RealSenseD415.CONFIG
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

        self.val_repeats = cfg['train']['val_repeats']
        self.save_steps = cfg['train']['save_steps']

        self._build_model()

        self._optimizers = {
            'attn': torch.optim.Adam(self.attention.parameters(), lr=self.cfg['train']['lr']),
            'trans': torch.optim.Adam(self.transport.parameters(), lr=self.cfg['train']['lr']),
        }
        print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

    def configure_optimizers(self):
        """Register no optimizers with Lightning.

        The agent optimizes manually through ``self._optimizers``. A dict of
        optimizers keyed 'attn'/'trans' is not a format Lightning accepts here.
        This single definition replaces an earlier duplicate that was always
        overridden by this one.
        """
        return None

    def _build_model(self):
        """Create ``self.attention`` and ``self.transport``; subclasses must override."""
        self.attention = None
        self.transport = None
        raise NotImplementedError()

    def forward(self, x):
        """Not used; inference goes through ``act``."""
        raise NotImplementedError()

    def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
        """Soft-label cross entropy between logits and batch-first target maps.

        Both tensors are flattened to ``(len(labels), -1)``. Log-softmax is taken
        over the flattened dimension of ``pred``.

        Args:
            pred: Logits with the same number of elements per sample as ``labels``.
            labels: Target distribution per sample (one-hot maps in this class).
            reduction: ``'sum'`` or ``'mean'``. ``'mean'`` averages over every
                element of the flattened product, not per sample.

        Returns:
            Scalar loss tensor.

        Raises:
            NotImplementedError: For any other ``reduction``.
        """
        batch = len(labels)
        # reshape (unlike view) also accepts non-contiguous network outputs.
        x = -labels.reshape(batch, -1) * F.log_softmax(pred.reshape(batch, -1), -1)
        if reduction == 'sum':
            return x.sum()
        if reduction == 'mean':
            return x.mean()
        raise NotImplementedError()

    def attn_forward(self, inp, softmax=True):
        """Run the attention (pick) model on ``inp['inp_img']``."""
        inp_img = inp['inp_img']
        output = self.attention.forward(inp_img, softmax=softmax)
        return output

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the pick loss for a batch ``frame``.

        Reads the keys 'img', 'p0' and 'p0_theta' from ``frame``.
        Returns ``(loss, err)``.
        """
        inp_img = frame['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']

        inp = {'inp_img': inp_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
        """Compute the pick loss, with an optional optimizer step and error metrics.

        Args:
            backprop: If True, take one step of the 'attn' optimizer on the loss.
            compute_err: If True, also run a softmax forward pass and compare the
                argmax of the first sample with its target.
            inp: Dict with key 'inp_img'. Dims 0-2 of the image size the label
                map as (batch, rows, cols).
            out: Attention logits for ``inp``, computed with softmax disabled.
            p: Per-sample target pixels; ``p[i][0]`` / ``p[i][1]`` index rows/cols.
            theta: Per-sample target angles in radians (tensor or array).

        Returns:
            ``(loss, err)``. ``err`` is ``{}`` unless ``compute_err`` is set, in
            which case it holds 'dist' (L1 pixel error) and 'theta'.
        """
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        n_rot = self.attention.n_rotations
        # Discretise each angle into one of n_rot rotation bins.
        theta_i = np.int32(np.round(theta / (2 * np.pi / n_rot))) % n_rot

        # One-hot target built as (batch, rows, cols, n_rot), then permuted to
        # (batch, n_rot, rows, cols). Only the shape of the image is needed.
        label_size = inp['inp_img'].shape[:3] + (n_rot,)
        label = torch.zeros(label_size, dtype=torch.float, device=out.device)
        for idx, p_i in enumerate(p):
            label[idx, int(p_i[0]), int(p_i[1]), theta_i[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        loss = self.cross_entropy_with_logits(out, label)

        if backprop:
            attn_optim = self._optimizers['attn']
            # Clear stale gradients before backward (same order as transport_criterion).
            attn_optim.zero_grad()
            self.manual_backward(loss)
            attn_optim.step()

        err = {}
        if compute_err:
            with torch.no_grad():
                pick_conf = self.attn_forward(inp)
            # First sample only, reordered so the rotation axis is last.
            pick_conf = pick_conf[0].permute(1, 2, 0).detach().cpu().numpy()
            argmax = np.unravel_index(np.argmax(pick_conf), shape=pick_conf.shape)
            p0_pix = argmax[:2]
            p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

            err = {
                'dist': np.linalg.norm(p[0].detach().cpu().numpy() - p0_pix, ord=1),
                'theta': np.absolute((theta[0] - p0_theta) % np.pi),
            }
        return loss, err

    def trans_forward(self, inp, softmax=True):
        """Run the transport (place) model on ``inp['inp_img']`` conditioned on ``inp['p0']``."""
        inp_img = inp['inp_img']
        p0 = inp['p0']

        output = self.transport.forward(inp_img, p0, softmax=softmax)
        return output

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the place loss for a batch ``frame``.

        Reads the keys 'img', 'p0', 'p1' and 'p1_theta' from ``frame``.
        Returns ``(loss, err)``; note that ``transport_criterion`` itself
        returns ``(err, loss)``.
        """
        inp_img = frame['img'].float()
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'p0': p0}
        output = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
        return loss, err

    def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
        """Compute the place loss, with an optional optimizer step and error metrics.

        Args:
            backprop: If True, take one step of the 'trans' optimizer on the loss.
            compute_err: If True, also run a softmax forward pass and compare the
                argmax of the first sample with its target.
            inp: Dict with keys 'inp_img' and 'p0'.
            output: Transport logits for ``inp``, computed with softmax disabled.
            p: Pick pixels (unused here; kept for interface compatibility).
            q: Target place pixels as a tensor indexed ``q[:, 0]`` (rows) and
                ``q[:, 1]`` (cols). The caller's tensor is not modified.
            theta: Per-sample target angles in radians (tensor or array).

        Returns:
            ``(err, loss)``. ``err`` is ``{}`` unless ``compute_err`` is set.
            Also increments ``self.transport.iters``.
        """
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        n_rot = self.transport.n_rotations
        # Discretise each angle into one of n_rot rotation bins.
        itheta = np.int32(np.round(theta / (2 * np.pi / n_rot))) % n_rot

        label_size = inp['inp_img'].shape[:3] + (n_rot,)
        label = torch.zeros(label_size, dtype=torch.float, device=output.device)

        # Clamp targets into the label map on a copy, so the batch tensor
        # (frame['p1']) is not mutated in place.
        q = q.clone()
        q[:, 0] = torch.clamp(q[:, 0], 0, label.shape[1] - 1)
        q[:, 1] = torch.clamp(q[:, 1], 0, label.shape[2] - 1)

        for idx, q_i in enumerate(q):
            label[idx, int(q_i[0]), int(q_i[1]), itheta[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        loss = self.cross_entropy_with_logits(output, label)

        if backprop:
            transport_optim = self._optimizers['trans']
            transport_optim.zero_grad()
            self.manual_backward(loss)
            transport_optim.step()

        err = {}
        if compute_err:
            with torch.no_grad():
                place_conf = self.trans_forward(inp)
            # First sample only, reordered so the rotation axis is last.
            place_conf = place_conf[0].permute(1, 2, 0).detach().cpu().numpy()
            argmax = np.unravel_index(np.argmax(place_conf), shape=place_conf.shape)
            p1_pix = argmax[:2]
            p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

            err = {
                'dist': np.linalg.norm(q[0].detach().cpu().numpy() - p1_pix, ord=1),
                'theta': np.absolute((theta[0] - p1_theta) % np.pi),
            }

        self.transport.iters += 1
        return err, loss

    def training_step(self, batch, batch_idx):
        """Train pick then place on one batch, log the losses and maybe checkpoint."""
        self.attention.train()
        self.transport.train()

        frame, _ = batch
        self.start_time = time.time()

        step = self.total_steps + 1
        loss0, err0 = self.attn_training_step(frame)

        self.start_time = time.time()

        # NOTE(review): when transport is an Attention, this re-runs the
        # attention loss, which trains self.attention rather than self.transport.
        if isinstance(self.transport, Attention):
            loss1, err1 = self.attn_training_step(frame)
        else:
            loss1, err1 = self.transport_training_step(frame)

        total_loss = loss0 + loss1
        self.total_steps = step
        self.start_time = time.time()
        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def check_save_iteration(self, interval=100):
        """Save ``last.ckpt`` once every ``interval`` completed training steps.

        Args:
            interval: Number of training steps between saves (default 100).
        """
        # training_step increments total_steps before calling this, so it
        # already counts the current batch; compare it directly.
        global_step = self.total_steps
        if global_step > 0 and global_step % interval == 0:
            print(f"Saving last.ckpt Epoch: {self.trainer.current_epoch} | Global Step: {self.trainer.global_step}")
            self.save_last_checkpoint()

    def save_last_checkpoint(self):
        """Write ``<train_dir>/checkpoints/last.ckpt`` through the trainer."""
        checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
        ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
        self.trainer.save_checkpoint(ckpt_path)

    def validation_step(self, batch, batch_idx):
        """Average the pick/place losses over ``val_repeats`` passes, without backprop.

        The error metrics returned are those of the last repeat.
        """
        self.attention.eval()
        self.transport.eval()

        assert self.val_repeats >= 1
        frame, _ = batch
        loss0, loss1 = 0, 0
        for _repeat in range(self.val_repeats):
            l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
            else:
                l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
            loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def training_epoch_end(self, all_outputs):
        """Re-seed the RNGs with the next epoch index."""
        super().training_epoch_end(all_outputs)
        utils.set_seed(self.trainer.current_epoch + 1)

    def validation_epoch_end(self, all_outputs):
        """Log mean validation losses and summed error metrics across batches."""
        mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
        mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
        mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
        total_attn_dist_err = np.sum([v['val_attn_dist_err'].sum() for v in all_outputs])
        total_attn_theta_err = np.sum([v['val_attn_theta_err'].sum() for v in all_outputs])
        total_trans_dist_err = np.sum([v['val_trans_dist_err'].sum() for v in all_outputs])
        total_trans_theta_err = np.sum([v['val_trans_theta_err'].sum() for v in all_outputs])

        self.log('vl/attn/loss', mean_val_loss0)
        self.log('vl/trans/loss', mean_val_loss1)
        self.log('vl/loss', mean_val_total_loss)
        self.log('vl/total_attn_dist_err', total_attn_dist_err)
        self.log('vl/total_attn_theta_err', total_attn_theta_err)
        self.log('vl/total_trans_dist_err', total_trans_dist_err)
        self.log('vl/total_trans_theta_err', total_trans_theta_err)

        print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
        print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

        return dict(
            val_loss=mean_val_total_loss,
            val_loss0=mean_val_loss0,
            mean_val_loss1=mean_val_loss1,
            total_attn_dist_err=total_attn_dist_err,
            total_attn_theta_err=total_attn_theta_err,
            total_trans_dist_err=total_trans_dist_err,
            total_trans_theta_err=total_trans_theta_err,
        )

    def act(self, obs, info=None, goal=None):
        """Run inference and return best action given visual observations.

        Returns a dict with 'pose0'/'pose1' (xyz position, xyzw quaternion) and
        the chosen 'pick'/'place' pixels.
        """
        img = self.test_ds.get_image(obs)

        # Pick: argmax over the attention confidence map.
        # NOTE(review): unlike attn_criterion this does not index [0] or
        # permute; assumes attention returns (rows, cols, n_rot) for a single
        # image — confirm against Attention.forward.
        pick_inp = {'inp_img': img}
        pick_conf = self.attn_forward(pick_inp)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Place: argmax over the transport confidence map conditioned on p0.
        place_inp = {'inp_img': img, 'p0': p0_pix}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels -> world coordinates using channel 3 of the image as heightmap.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        """No-op: optimizer steps happen manually in the criterion methods."""
        pass

    def train_dataloader(self):
        """Return the training loader passed to ``__init__``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader passed to ``__init__``."""
        return self.test_loader

    def load(self, model_path):
        """Load weights from a Lightning checkpoint and move to ``self.device_type``.

        ``map_location`` lets a checkpoint saved on GPU load on a CPU-only host.
        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        checkpoint = torch.load(model_path, map_location=self.device_type)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(device=self.device_type)
|
|
|
|
class OriginalTransporterAgent(TransporterAgent):
    """Single-stream agent: 'plain_resnet' for both the pick and place models."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=('plain_resnet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = Attention(n_rotations=1, **shared_kwargs)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|
|
|
class ClipUNetTransporterAgent(TransporterAgent):
    """Single-stream agent: 'clip_unet' for both the pick and place models."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=('clip_unet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = Attention(n_rotations=1, **shared_kwargs)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|
|
|
class TwoStreamClipUNetTransporterAgent(TransporterAgent):
    """Two-stream agent: 'plain_resnet' + 'clip_unet' for both pick and place."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet', 'clip_unet')
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared_kwargs)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|
|
|
class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
    """Two-stream agent with lateral variants: 'plain_resnet_lat' + 'clip_unet_lat'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet_lat', 'clip_unet_lat')
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLat(n_rotations=1, **shared_kwargs)
        self.transport = TwoStreamTransportLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|
|
|
class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
    """Two-stream agent: 'plain_resnet' + 'clip_woskip' for both pick and place."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet', 'clip_woskip')
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared_kwargs)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|
|
|
class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
    """Two-stream agent: 'plain_resnet' + 'rn50_bert_unet' for both pick and place."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet', 'rn50_bert_unet')
        # Keyword arguments shared by the attention and transport modules.
        shared_kwargs = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared_kwargs)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared_kwargs,
        )
|
|