# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from transfer_util import initialize_model
from stage1_models import BasicBlock, ResNet84
import os
import copy
from PIL import Image
import platform
from numbers import Number
import utils
class RandomShiftsAug(nn.Module):
def __init__(self, pad):
super().__init__()
self.pad = pad
def forward(self, x):
n, c, h, w = x.size()
assert h == w
padding = tuple([self.pad] * 4)
x = F.pad(x, padding, 'replicate')
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(-1.0 + eps,
1.0 - eps,
h + 2 * self.pad,
device=x.device,
dtype=x.dtype)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(0,
2 * self.pad + 1,
size=(n, 1, 1, 2),
device=x.device,
dtype=x.dtype)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift
return F.grid_sample(x,
grid,
padding_mode='zeros',
align_corners=False)
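# Minimal smoke-test sketch (not part of the original file; the function name is
# hypothetical): RandomShiftsAug pads with edge replication, then bilinearly
# resamples at a randomly shifted grid, so the output keeps the input's shape.
def _demo_random_shifts_aug():
    aug = RandomShiftsAug(pad=4)
    imgs = torch.rand(2, 9, 84, 84)  # fake batch of 3 stacked RGB frames; H == W is required
    out = aug(imgs)
    assert out.shape == imgs.shape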
class Identity(nn.Module):
def __init__(self, input_placeholder=None):
        super().__init__()
def forward(self, x):
return x
class RLEncoder(nn.Module):
def __init__(self, obs_shape, model_name, device):
super().__init__()
# a wrapper over a non-RL encoder model
self.device = device
assert len(obs_shape) == 3
self.n_input_channel = obs_shape[0]
assert self.n_input_channel % 3 == 0
self.n_images = self.n_input_channel // 3
self.model = self.init_model(model_name)
self.model.fc = Identity()
self.repr_dim = self.model.get_feature_size()
self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
self.channel_mismatch = True
def init_model(self, model_name):
# model name is e.g. resnet6_32channel
n_layer_string, n_channel_string = model_name.split('_')
layer_string_to_layer_list = {
'resnet6': [0, 0, 0, 0],
'resnet10': [1, 1, 1, 1],
'resnet18': [2, 2, 2, 2],
}
channel_string_to_n_channel = {
'32channel': 32,
'64channel': 64,
}
layer_list = layer_string_to_layer_list[n_layer_string]
start_num_channel = channel_string_to_n_channel[n_channel_string]
return ResNet84(BasicBlock, layer_list, start_num_channel=start_num_channel).to(self.device)
def expand_first_layer(self):
# convolutional channel expansion to deal with input mismatch
multiplier = self.n_images
        self.model.conv1.weight.data = self.model.conv1.weight.data.repeat(1, multiplier, 1, 1) / multiplier
means = (0.485, 0.456, 0.406) * multiplier
stds = (0.229, 0.224, 0.225) * multiplier
self.normalize_op = transforms.Normalize(means, stds)
self.channel_mismatch = False
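    # Worked example (sketch): with 3 stacked RGB frames, n_images == 3, so a conv1
    # weight of shape (C_out, 3, k, k) is tiled to (C_out, 9, k, k). Dividing by the
    # multiplier keeps conv1's output unchanged when the same frame is repeated 3
    # times, so the pretrained stage-1 features are preserved when RL starts.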
def freeze_bn(self):
# freeze batch norm layers (VRL3 ablation shows modifying how
# batch norm is trained does not affect performance)
for module in self.model.modules():
if isinstance(module, nn.BatchNorm2d):
if hasattr(module, 'weight'):
module.weight.requires_grad_(False)
if hasattr(module, 'bias'):
module.bias.requires_grad_(False)
module.eval()
def get_parameters_that_require_grad(self):
params = []
for name, param in self.named_parameters():
            if param.requires_grad:
params.append(param)
return params
def transform_obs_tensor_batch(self, obs):
        # transform the obs batch before feeding it into the pretrained resnet
new_obs = self.normalize_op(obs.float()/255)
return new_obs
def _forward_impl(self, x):
x = self.model.get_features(x)
return x
def forward(self, obs):
o = self.transform_obs_tensor_batch(obs)
h = self._forward_impl(o)
return h
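# Usage sketch (hypothetical, not part of the original file). Assumes
# stage1_models.ResNet84 exposes conv1 / get_features / get_feature_size as used above.
def _demo_rl_encoder():
    enc = RLEncoder((9, 84, 84), 'resnet18_32channel', 'cpu')
    enc.expand_first_layer()  # conv1 now accepts 9 input channels
    obs = torch.randint(0, 256, (2, 9, 84, 84), dtype=torch.uint8)
    feat = enc(obs)           # shape (2, enc.repr_dim)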
class Stage3ShallowEncoder(nn.Module):
def __init__(self, obs_shape, n_channel):
super().__init__()
assert len(obs_shape) == 3
self.repr_dim = n_channel * 35 * 35
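        # with 84x84 inputs: 84 -> 41 (3x3 conv, stride 2) -> 39 -> 37 -> 35 over the
        # three stride-1 3x3 convs, hence the n_channel * 35 * 35 flattened size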
self.n_input_channel = obs_shape[0]
self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
self.relu = nn.ReLU(inplace=True)
        # prediction head (compress + pred_layer below) enables contrastive learning
self.apply(utils.weight_init)
self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406, 0.485, 0.456, 0.406, 0.485, 0.456, 0.406),
(0.229, 0.224, 0.225, 0.229, 0.224, 0.225, 0.229, 0.224, 0.225))
self.compress = nn.Sequential(nn.Linear(self.repr_dim, 50), nn.LayerNorm(50), nn.Tanh())
self.pred_layer = nn.Linear(50, 50, bias=False)
def transform_obs_tensor_batch(self, obs):
        # transform the obs batch before feeding it into the conv encoder
# correct order might be first augment, then resize, then normalize
# obs = F.interpolate(obs, size=self.pretrained_model_input_size)
new_obs = obs / 255.0 - 0.5
# new_obs = self.normalize_op(new_obs)
return new_obs
def _forward_impl(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.relu(self.conv4(x))
return x
def forward(self, obs):
o = self.transform_obs_tensor_batch(obs)
h = self._forward_impl(o)
h = h.view(h.shape[0], -1)
return h
def get_anchor_output(self, obs, actions=None):
        # goes through the conv layers, then the compression layer, then the prediction head
        # used for the UL update
conv_out = self.forward(obs)
compressed = self.compress(conv_out)
pred = self.pred_layer(compressed)
return pred, conv_out
def get_positive_output(self, obs):
        # goes through the conv layers and the compression layer only
        # used for the UL update
conv_out = self.forward(obs)
compressed = self.compress(conv_out)
return compressed
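# UL-update sketch (hypothetical, not part of the original file): a CURL-style
# contrastive loss would compare the anchor prediction against the positive
# embedding computed from a differently augmented view of the same observations.
def _demo_contrastive_loss(encoder, obs_a, obs_b):
    pred, _ = encoder.get_anchor_output(obs_a)  # conv -> compress -> prediction head
    pos = encoder.get_positive_output(obs_b)    # conv -> compress only
    logits = pred @ pos.T                       # (B, B) similarity matrix
    labels = torch.arange(logits.shape[0], device=logits.device)
    return F.cross_entropy(logits, labels)      # positives lie on the diagonal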
class Encoder(nn.Module):
def __init__(self, obs_shape, n_channel):
super().__init__()
assert len(obs_shape) == 3
self.repr_dim = n_channel * 35 * 35
self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], n_channel, 3, stride=2),
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
nn.ReLU())
self.apply(utils.weight_init)
def forward(self, obs):
obs = obs / 255.0 - 0.5
h = self.convnet(obs)
h = h.view(h.shape[0], -1)
return h
class IdentityEncoder(nn.Module):
def __init__(self, obs_shape):
super().__init__()
assert len(obs_shape) == 1
self.repr_dim = obs_shape[0]
def forward(self, obs):
return obs
class Actor(nn.Module):
def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
nn.LayerNorm(feature_dim), nn.Tanh())
self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, action_shape[0]))
        self.action_shift = 0
        self.action_scale = 1
self.apply(utils.weight_init)
def forward(self, obs, std):
h = self.trunk(obs)
mu = self.policy(h)
mu = torch.tanh(mu)
mu = mu * self.action_scale + self.action_shift
std = torch.ones_like(mu) * std
dist = utils.TruncatedNormal(mu, std)
return dist
def forward_with_pretanh(self, obs, std):
h = self.trunk(obs)
mu = self.policy(h)
pretanh = mu
mu = torch.tanh(mu)
mu = mu * self.action_scale + self.action_shift
std = torch.ones_like(mu) * std
dist = utils.TruncatedNormal(mu, std)
return dist, pretanh
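# Sketch (hypothetical, not part of the original file): sampling an action from the
# actor's truncated-Normal policy given already-encoded observation features.
def _demo_actor_sample(actor, obs_feat):
    dist = actor(obs_feat, std=0.1)  # fixed exploration std for illustration
    action = dist.sample(clip=0.3)   # clip the sampling noise, as in the updates below
    return action                    # squashed into [-1, 1] by tanh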
class Critic(nn.Module):
def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
super().__init__()
self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
nn.LayerNorm(feature_dim), nn.Tanh())
self.Q1 = nn.Sequential(
nn.Linear(feature_dim + action_shape[0], hidden_dim),
nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
self.Q2 = nn.Sequential(
nn.Linear(feature_dim + action_shape[0], hidden_dim),
nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
self.apply(utils.weight_init)
def forward(self, obs, action):
h = self.trunk(obs)
h_action = torch.cat([h, action], dim=-1)
q1 = self.Q1(h_action)
q2 = self.Q2(h_action)
return q1, q2
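# Note: the two heads implement clipped double-Q learning; the critic target below
# takes torch.min(Q1, Q2) (TD3/SAC-style) to curb Q-value overestimation.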
class VRL3Agent:
def __init__(self, obs_shape, action_shape, device, use_sensor, lr, feature_dim,
hidden_dim, critic_target_tau, num_expl_steps,
update_every_steps, stddev_clip, use_tb, use_data_aug, encoder_lr_scale,
stage1_model_name, safe_q_target_factor, safe_q_threshold, pretanh_penalty, pretanh_threshold,
stage2_update_encoder, cql_weight, cql_temp, cql_n_random, stage2_std, stage2_bc_weight,
stage3_update_encoder, std0, std1, std_n_decay,
stage3_bc_lam0, stage3_bc_lam1):
self.device = device
self.critic_target_tau = critic_target_tau
self.update_every_steps = update_every_steps
self.use_tb = use_tb
self.num_expl_steps = num_expl_steps
self.stage2_std = stage2_std
self.stage2_update_encoder = stage2_update_encoder
if std1 > std0:
std1 = std0
self.stddev_schedule = "linear(%s,%s,%s)" % (str(std0), str(std1), str(std_n_decay))
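        # e.g. std0=1.0, std1=0.1, std_n_decay=500000 (hypothetical values) yields
        # "linear(1.0,0.1,500000)": utils.schedule then decays the exploration std
        # linearly from 1.0 to 0.1 over the first 500k steps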
self.stddev_clip = stddev_clip
self.use_data_aug = use_data_aug
self.safe_q_target_factor = safe_q_target_factor
self.q_threshold = safe_q_threshold
self.pretanh_penalty = pretanh_penalty
self.cql_temp = cql_temp
self.cql_weight = cql_weight
self.cql_n_random = cql_n_random
self.pretanh_threshold = pretanh_threshold
self.stage2_bc_weight = stage2_bc_weight
self.stage3_bc_lam0 = stage3_bc_lam0
self.stage3_bc_lam1 = stage3_bc_lam1
if stage3_update_encoder and encoder_lr_scale > 0 and len(obs_shape) > 1:
self.stage3_update_encoder = True
else:
self.stage3_update_encoder = False
self.encoder = RLEncoder(obs_shape, stage1_model_name, device).to(device)
self.act_dim = action_shape[0]
if use_sensor:
downstream_input_dim = self.encoder.repr_dim + 24
else:
downstream_input_dim = self.encoder.repr_dim
self.actor = Actor(downstream_input_dim, action_shape, feature_dim,
hidden_dim).to(device)
self.critic = Critic(downstream_input_dim, action_shape, feature_dim,
hidden_dim).to(device)
self.critic_target = Critic(downstream_input_dim, action_shape,
feature_dim, hidden_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
# optimizers
self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
encoder_lr = lr * encoder_lr_scale
""" set up encoder optimizer """
self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=encoder_lr)
# data augmentation
self.aug = RandomShiftsAug(pad=4)
self.train()
self.critic_target.train()
def load_pretrained_encoder(self, model_path, verbose=True):
if verbose:
print("Trying to load pretrained model from:", model_path)
checkpoint = torch.load(model_path, map_location=torch.device(self.device))
state_dict = checkpoint['state_dict']
pretrained_dict = {}
        # remove the `module.` prefix if the model was pretrained in distributed mode
for k, v in state_dict.items():
if 'module.' in k:
name = k[7:]
else:
name = k
pretrained_dict[name] = v
self.encoder.model.load_state_dict(pretrained_dict, strict=False)
if verbose:
print("Pretrained model loaded!")
def switch_to_RL_stages(self, verbose=True):
# run convolutional channel expansion to match input shape
self.encoder.expand_first_layer()
if verbose:
print("Convolutional channel expansion finished: now can take in %d images as input." % self.encoder.n_images)
def train(self, training=True):
self.training = training
self.encoder.train(training)
self.actor.train(training)
self.critic.train(training)
def act(self, obs, step, eval_mode, obs_sensor=None, is_tensor_input=False, force_action_std=None):
"""
obs: 3x84x84, uint8, [0,255]
"""
# eval_mode should be False when taking an exploration action in stage 3
        # eval_mode should be True when evaluating agent performance
        if force_action_std is None:
stddev = utils.schedule(self.stddev_schedule, step)
if step < self.num_expl_steps and not eval_mode:
action = np.random.uniform(0, 1, (self.act_dim,)).astype(np.float32)
return action
else:
stddev = force_action_std
if is_tensor_input:
obs = self.encoder(obs)
else:
obs = torch.as_tensor(obs, device=self.device)
obs = self.encoder(obs.unsqueeze(0))
if obs_sensor is not None:
obs_sensor = torch.as_tensor(obs_sensor, device=self.device)
obs_sensor = obs_sensor.unsqueeze(0)
obs_combined = torch.cat([obs, obs_sensor], dim=1)
else:
obs_combined = obs
dist = self.actor(obs_combined, stddev)
if eval_mode:
action = dist.mean
else:
action = dist.sample(clip=None)
if step < self.num_expl_steps:
action.uniform_(-1.0, 1.0)
return action.cpu().numpy()[0]
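    # Sketch of one stage-3 interaction step (the env API is hypothetical):
    #   action = agent.act(obs, step, eval_mode=False)  # obs: uint8 frame stack
    #   next_obs, reward, done, info = env.step(action)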
def update(self, replay_iter, step, stage, use_sensor):
        # stages 2 and 3 use the same update functions, but with different hyperparameters
assert stage in (2, 3)
metrics = dict()
if stage == 2:
update_encoder = self.stage2_update_encoder
stddev = self.stage2_std
conservative_loss_weight = self.cql_weight
bc_weight = self.stage2_bc_weight
if stage == 3:
if step % self.update_every_steps != 0:
return metrics
update_encoder = self.stage3_update_encoder
stddev = utils.schedule(self.stddev_schedule, step)
conservative_loss_weight = 0
# compute stage 3 BC weight
bc_data_per_iter = 40000
i_iter = step // bc_data_per_iter
bc_weight = self.stage3_bc_lam0 * self.stage3_bc_lam1 ** i_iter
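            # e.g. with stage3_bc_lam0=0.9 and stage3_bc_lam1=0.9 (hypothetical values),
            # step=100000 gives i_iter=2 and bc_weight = 0.9 * 0.9**2 = 0.729, so the
            # BC term decays geometrically over training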
# batch data
batch = next(replay_iter)
if use_sensor: # TODO might want to...?
obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = utils.to_torch(batch, self.device)
else:
obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
obs_sensor, obs_sensor_next = None, None
# augment
if self.use_data_aug:
obs = self.aug(obs.float())
next_obs = self.aug(next_obs.float())
else:
obs = obs.float()
next_obs = next_obs.float()
# encode
if update_encoder:
obs = self.encoder(obs)
else:
with torch.no_grad():
obs = self.encoder(obs)
with torch.no_grad():
next_obs = self.encoder(next_obs)
# concatenate obs with additional sensor observation if needed
obs_combined = torch.cat([obs, obs_sensor], dim=1) if obs_sensor is not None else obs
obs_next_combined = torch.cat([next_obs, obs_sensor_next], dim=1) if obs_sensor_next is not None else next_obs
# update critic
metrics.update(self.update_critic_vrl3(obs_combined, action, reward, discount, obs_next_combined,
stddev, update_encoder, conservative_loss_weight))
        # update actor; following prior work, we do not use the actor gradient to update the encoder
metrics.update(self.update_actor_vrl3(obs_combined.detach(), action, stddev, bc_weight,
self.pretanh_penalty, self.pretanh_threshold))
metrics['batch_reward'] = reward.mean().item()
# update critic target networks
utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
return metrics
def update_critic_vrl3(self, obs, action, reward, discount, next_obs, stddev, update_encoder, conservative_loss_weight):
metrics = dict()
batch_size = obs.shape[0]
"""
STANDARD Q LOSS COMPUTATION:
- get standard Q loss first, this is the same as in any other online RL methods
- except for the safe Q technique, which controls how large the Q value can be
"""
with torch.no_grad():
dist = self.actor(next_obs, stddev)
next_action = dist.sample(clip=self.stddev_clip)
target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
target_V = torch.min(target_Q1, target_Q2)
target_Q = reward + (discount * target_V)
if self.safe_q_target_factor < 1:
target_Q[target_Q > (self.q_threshold + 1)] = self.q_threshold + (target_Q[target_Q > (self.q_threshold+1)] - self.q_threshold) ** self.safe_q_target_factor
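                # worked example: with q_threshold=200 and safe_q_target_factor=0.5
                # (hypothetical values), a raw target of 300 becomes
                # 200 + (300 - 200)**0.5 = 210, softly compressing oversized targets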
Q1, Q2 = self.critic(obs, action)
critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
"""
CONSERVATIVE Q LOSS COMPUTATION:
- sample random actions, actions from policy and next actions from policy, as done in CQL authors' code
(though this detail is not really discussed in the CQL paper)
- only compute this loss when conservative loss weight > 0
"""
if conservative_loss_weight > 0:
random_actions = (torch.rand((batch_size * self.cql_n_random, self.act_dim), device=self.device) - 0.5) * 2
dist = self.actor(obs, stddev)
current_actions = dist.sample(clip=self.stddev_clip)
dist = self.actor(next_obs, stddev)
next_current_actions = dist.sample(clip=self.stddev_clip)
# now get Q values for all these actions (for both Q networks)
obs_repeat = obs.unsqueeze(1).repeat(1, self.cql_n_random, 1).view(obs.shape[0] * self.cql_n_random,
obs.shape[1])
            Q1_rand, Q2_rand = self.critic(obs_repeat,
                                           random_actions)  # TODO might want to double-check that the repeat here is correct
Q1_rand = Q1_rand.view(obs.shape[0], self.cql_n_random)
Q2_rand = Q2_rand.view(obs.shape[0], self.cql_n_random)
Q1_curr, Q2_curr = self.critic(obs, current_actions)
Q1_curr_next, Q2_curr_next = self.critic(obs, next_current_actions)
# now concat all these Q values together
Q1_cat = torch.cat([Q1_rand, Q1, Q1_curr, Q1_curr_next], 1)
Q2_cat = torch.cat([Q2_rand, Q2, Q2_curr, Q2_curr_next], 1)
cql_min_q1_loss = torch.logsumexp(Q1_cat / self.cql_temp,
dim=1, ).mean() * conservative_loss_weight * self.cql_temp
cql_min_q2_loss = torch.logsumexp(Q2_cat / self.cql_temp,
dim=1, ).mean() * conservative_loss_weight * self.cql_temp
"""Subtract the log likelihood of data"""
conservative_q_loss = cql_min_q1_loss + cql_min_q2_loss - (Q1.mean() + Q2.mean()) * conservative_loss_weight
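            # the logsumexp acts as a soft maximum over the sampled actions, so the
            # penalty pushes Q down on out-of-distribution actions while the
            # -(Q1.mean() + Q2.mean()) term pushes Q up on dataset actions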
critic_loss_combined = critic_loss + conservative_q_loss
else:
critic_loss_combined = critic_loss
# logging
metrics['critic_target_q'] = target_Q.mean().item()
metrics['critic_q1'] = Q1.mean().item()
metrics['critic_q2'] = Q2.mean().item()
metrics['critic_loss'] = critic_loss.item()
# if needed, also update encoder with critic loss
if update_encoder:
self.encoder_opt.zero_grad(set_to_none=True)
self.critic_opt.zero_grad(set_to_none=True)
critic_loss_combined.backward()
self.critic_opt.step()
if update_encoder:
self.encoder_opt.step()
return metrics
def update_actor_vrl3(self, obs, action, stddev, bc_weight, pretanh_penalty, pretanh_threshold):
metrics = dict()
"""
get standard actor loss
"""
dist, pretanh = self.actor.forward_with_pretanh(obs, stddev)
current_action = dist.sample(clip=self.stddev_clip)
log_prob = dist.log_prob(current_action).sum(-1, keepdim=True)
Q1, Q2 = self.critic(obs, current_action)
Q = torch.min(Q1, Q2)
actor_loss = -Q.mean()
"""
add BC loss
"""
if bc_weight > 0:
# get mean action with no action noise (though this might not be necessary)
stddev_bc = 0
dist_bc = self.actor(obs, stddev_bc)
current_mean_action = dist_bc.sample(clip=self.stddev_clip)
actor_loss_bc = F.mse_loss(current_mean_action, action) * bc_weight
else:
            actor_loss_bc = torch.zeros(1, device=self.device)
"""
add pretanh penalty (might not be necessary for Adroit)
"""
pretanh_loss = 0
if pretanh_penalty > 0:
pretanh_loss = pretanh.abs() - pretanh_threshold
pretanh_loss[pretanh_loss < 0] = 0
pretanh_loss = (pretanh_loss ** 2).mean() * pretanh_penalty
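            # e.g. with pretanh_threshold=5 (hypothetical) and a pre-tanh value of 7,
            # the excess is 2, contributing 2**2 = 4 to the squared-excess mean, which
            # is then scaled by pretanh_penalty; values within the threshold contribute
            # nothing, so only near-saturated tanh inputs are penalized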
"""
combine actor losses and optimize
"""
actor_loss_combined = actor_loss + actor_loss_bc + pretanh_loss
self.actor_opt.zero_grad(set_to_none=True)
actor_loss_combined.backward()
self.actor_opt.step()
metrics['actor_loss'] = actor_loss.item()
metrics['actor_loss_bc'] = actor_loss_bc.item()
metrics['actor_logprob'] = log_prob.mean().item()
metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
metrics['abs_pretanh'] = pretanh.abs().mean().item()
metrics['max_abs_pretanh'] = pretanh.abs().max().item()
return metrics
    def to(self, device):
        self.actor.to(device)
        self.critic.to(device)
        self.critic_target.to(device)  # keep the target network on the same device
        self.encoder.to(device)
        self.device = device