|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
import time |
|
|
|
|
|
|
|
|
import asyncio |
|
|
from PIL import Image |
|
|
import datetime |
|
|
import pickle as pkl |
|
|
import random |
|
|
import cv2 |
|
|
|
|
|
from prompt import ( |
|
|
gemini_free_query_env_prompts, gemini_summary_env_prompts, |
|
|
gemini_free_query_prompt1, gemini_free_query_prompt2, |
|
|
gemini_single_query_env_prompts, |
|
|
gpt_free_query_env_prompts, gpt_summary_env_prompts, |
|
|
) |
|
|
from vlms.gemini_infer import gemini_query_2, gemini_query_1 |
|
|
from conv_net import CNN, fanin_init |
|
|
|
|
|
device = 'cuda' |
|
|
|
|
|
def gen_net(in_size=1, out_size=1, H=128, n_layers=3, activation='tanh'):
    """Build an MLP as a flat list of torch layers (not yet wrapped in Sequential).

    The body is `n_layers` blocks of Linear(H) + LeakyReLU, followed by a
    Linear projection to `out_size` and a final activation chosen by
    `activation` ('tanh', 'sig', or anything else -> ReLU).
    """
    layers = []
    width = in_size
    for _ in range(n_layers):
        layers += [nn.Linear(width, H), nn.LeakyReLU()]
        width = H
    layers.append(nn.Linear(width, out_size))

    # Output squashing: tanh -> [-1, 1], sigmoid -> [0, 1], otherwise ReLU.
    if activation == 'tanh':
        final_act = nn.Tanh()
    elif activation == 'sig':
        final_act = nn.Sigmoid()
    else:
        final_act = nn.ReLU()
    layers.append(final_act)

    return layers
|
|
|
|
|
def gen_image_net(image_height, image_width,
                  conv_kernel_sizes=[5, 3, 3, 3],
                  conv_n_channels=[16, 32, 64, 128],
                  conv_strides=[3, 2, 2, 2]):
    """Construct a CNN reward model over RGB images with a single scalar output.

    The conv stack is described by the three parallel lists (one entry per
    conv layer); padding is zero everywhere and there are no FC hidden
    layers or batch norm.
    """
    num_conv_layers = len(conv_kernel_sizes)
    return CNN(
        kernel_sizes=conv_kernel_sizes,
        n_channels=conv_n_channels,
        strides=conv_strides,
        output_size=1,                                       # scalar reward head
        paddings=np.zeros(num_conv_layers, dtype=np.int64),  # no padding on any conv
        input_height=image_height,
        input_width=image_width,
        input_channels=3,                                    # RGB input
        init_w=1e-3,
        hidden_init=fanin_init,
        hidden_sizes=[],                                     # no FC hidden layers
        batch_norm_conv=False,
        batch_norm_fc=False,
    )
|
|
|
|
|
def gen_image_net2():
    """Build a ResNet-18-shaped regressor whose final FC emits one scalar."""
    from torchvision.models.resnet import BasicBlock, ResNet

    # [2, 2, 2, 2] BasicBlocks is the standard ResNet-18 layout;
    # num_classes=1 turns the classifier head into a scalar reward output.
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1)
|
|
|
|
|
def KCenterGreedy(obs, full_obs, num_new_sample):
    """Greedy k-center selection of `num_new_sample` rows from `obs`.

    Each step picks the candidate farthest (in min-distance) from the current
    coverage set, which is `full_obs` plus everything selected so far, then
    removes it from the candidate pool. Returns the selected row indices
    into the original `obs`.
    """
    selected_index = []
    remaining = list(range(obs.shape[0]))  # maps candidate rows back to `obs` rows
    candidates = obs
    coverage = full_obs
    start_time = time.time()  # kept from the original; the timing is never reported
    for step in range(num_new_sample):
        nearest = compute_smallest_dist(candidates, coverage)
        farthest = torch.argmax(nearest).item()

        # On the first step `remaining` is still the identity mapping,
        # so the raw index and the mapped index coincide.
        if step == 0:
            selected_index.append(farthest)
        else:
            selected_index.append(remaining[farthest])
        remaining = remaining[0:farthest] + remaining[farthest+1:]

        candidates = obs[remaining]
        coverage = np.concatenate([full_obs, obs[selected_index]], axis=0)
    return selected_index
|
|
|
|
|
def compute_smallest_dist(obs, full_obs):
    """Distance from each row of `obs` to its nearest neighbour in `full_obs`.

    Both inputs are numpy arrays; pairwise L2 distances are computed in
    100-row tiles to bound device memory. Returns a (len(obs), 1) tensor.
    """
    obs = torch.from_numpy(obs).float()
    full_obs = torch.from_numpy(full_obs).float()
    chunk = 100
    with torch.no_grad():
        per_query_min = []
        for qi in range(len(obs) // chunk + 1):
            q_lo = qi * chunk
            if q_lo >= len(obs):
                continue  # past the end: this tile is empty
            q_hi = (qi + 1) * chunk
            tile_dists = []
            for ri in range(len(full_obs) // chunk + 1):
                r_lo = ri * chunk
                if r_lo >= len(full_obs):
                    continue
                r_hi = (ri + 1) * chunk
                # broadcast (q, 1, d) - (1, r, d) -> (q, r) pairwise L2 norms
                pair = torch.norm(
                    obs[q_lo:q_hi, None, :].to(device) - full_obs[None, r_lo:r_hi, :].to(device),
                    dim=-1, p=2,
                )
                tile_dists.append(pair)
            tile_dists = torch.cat(tile_dists, dim=1)
            per_query_min.append(torch.min(tile_dists, dim=1).values)

        per_query_min = torch.cat(per_query_min)
    return per_query_min.unsqueeze(1)
|
|
|
|
|
class RewardModel: |
|
|
def __init__(self, ds, da, |
|
|
ensemble_size=3, lr=3e-4, mb_size = 128, size_segment=1, |
|
|
max_size=100, activation='tanh', capacity=5e5, |
|
|
large_batch=1, label_margin=0.0, |
|
|
teacher_beta=-1, teacher_gamma=1, |
|
|
teacher_eps_mistake=0, |
|
|
teacher_eps_skip=0, |
|
|
teacher_eps_equal=0, |
|
|
|
|
|
|
|
|
vlm_label=True, |
|
|
env_name="CartPole-v1", |
|
|
vlm="gemini_free_form", |
|
|
clip_prompt=None, |
|
|
log_dir=None, |
|
|
flip_vlm_label=False, |
|
|
save_query_interval=25, |
|
|
cached_label_path=None, |
|
|
|
|
|
|
|
|
reward_model_layers=3, |
|
|
reward_model_H=256, |
|
|
image_reward=True, |
|
|
image_height=128, |
|
|
image_width=128, |
|
|
resize_factor=1, |
|
|
resnet=False, |
|
|
conv_kernel_sizes=[5, 3, 3 ,3], |
|
|
conv_n_channels=[16, 32, 64, 128], |
|
|
conv_strides=[3, 2, 2, 2], |
|
|
**kwargs |
|
|
): |
|
|
|
|
|
|
|
|
self.ds = ds |
|
|
self.da = da |
|
|
self.de = ensemble_size |
|
|
self.lr = lr |
|
|
self.ensemble = [] |
|
|
self.paramlst = [] |
|
|
self.opt = None |
|
|
self.model = None |
|
|
self.max_size = max_size |
|
|
self.activation = activation |
|
|
self.size_segment = size_segment |
|
|
|
|
|
self.capacity = int(capacity) |
|
|
self.reward_model_layers = reward_model_layers |
|
|
self.reward_model_H = reward_model_H |
|
|
self.image_reward = image_reward |
|
|
self.resnet = resnet |
|
|
self.conv_kernel_sizes = conv_kernel_sizes |
|
|
self.conv_n_channels = conv_n_channels |
|
|
self.conv_strides = conv_strides |
|
|
|
|
|
if not image_reward: |
|
|
self.buffer_seg1 = np.empty((self.capacity, size_segment, self.ds+self.da), dtype=np.float32) |
|
|
self.buffer_seg2 = np.empty((self.capacity, size_segment, self.ds+self.da), dtype=np.float32) |
|
|
else: |
|
|
assert self.size_segment == 1 |
|
|
self.buffer_seg1 = np.empty((self.capacity, 1, image_height, image_width, 3), dtype=np.uint8) |
|
|
self.buffer_seg2 = np.empty((self.capacity, 1, image_height, image_width, 3), dtype=np.uint8) |
|
|
self.image_height = image_height |
|
|
self.image_width = image_width |
|
|
self.resize_factor = resize_factor |
|
|
|
|
|
self.buffer_label = np.empty((self.capacity, 1), dtype=np.float32) |
|
|
self.buffer_index = 0 |
|
|
self.buffer_full = False |
|
|
|
|
|
self.construct_ensemble() |
|
|
self.inputs = [] |
|
|
self.targets = [] |
|
|
self.raw_actions = [] |
|
|
self.img_inputs = [] |
|
|
self.mb_size = mb_size |
|
|
self.origin_mb_size = mb_size |
|
|
if not image_reward: |
|
|
self.train_batch_size = 128 |
|
|
else: |
|
|
if not self.resnet: |
|
|
self.train_batch_size = 64 |
|
|
else: |
|
|
self.train_batch_size = 32 |
|
|
self.CEloss = nn.CrossEntropyLoss() |
|
|
self.running_means = [] |
|
|
self.running_stds = [] |
|
|
self.best_seg = [] |
|
|
self.best_label = [] |
|
|
self.best_action = [] |
|
|
self.large_batch = large_batch |
|
|
|
|
|
|
|
|
self.teacher_beta = teacher_beta |
|
|
self.teacher_gamma = teacher_gamma |
|
|
self.teacher_eps_mistake = teacher_eps_mistake |
|
|
self.teacher_eps_equal = teacher_eps_equal |
|
|
self.teacher_eps_skip = teacher_eps_skip |
|
|
self.teacher_thres_skip = 0 |
|
|
self.teacher_thres_equal = 0 |
|
|
|
|
|
self.label_margin = label_margin |
|
|
self.label_target = 1 - 2*self.label_margin |
|
|
|
|
|
|
|
|
self.vlm_label = vlm_label |
|
|
self.env_name = env_name |
|
|
self.vlm = vlm |
|
|
self.clip_prompt = clip_prompt |
|
|
self.vlm_label_acc = 0 |
|
|
self.log_dir = log_dir |
|
|
self.flip_vlm_label = flip_vlm_label |
|
|
self.train_times = 0 |
|
|
self.save_query_interval = save_query_interval |
|
|
|
|
|
|
|
|
file_path = os.path.abspath(__file__) |
|
|
dir_path = os.path.dirname(file_path) |
|
|
self.cached_label_path = "{}/{}".format(dir_path, cached_label_path) |
|
|
self.read_cache_idx = 0 |
|
|
if self.cached_label_path is not None: |
|
|
all_cached_labels = sorted(os.listdir(self.cached_label_path)) |
|
|
self.all_cached_labels = [os.path.join(self.cached_label_path, x) for x in all_cached_labels] |
|
|
|
|
|
def eval(self,): |
|
|
for i in range(self.de): |
|
|
self.ensemble[i].eval() |
|
|
|
|
|
def train(self,): |
|
|
for i in range(self.de): |
|
|
self.ensemble[i].train() |
|
|
|
|
|
def softXEnt_loss(self, input, target): |
|
|
logprobs = torch.nn.functional.log_softmax (input, dim = 1) |
|
|
return -(target * logprobs).sum() / input.shape[0] |
|
|
|
|
|
def change_batch(self, new_frac): |
|
|
self.mb_size = int(self.origin_mb_size*new_frac) |
|
|
|
|
|
def set_batch(self, new_batch): |
|
|
self.mb_size = int(new_batch) |
|
|
|
|
|
def set_teacher_thres_skip(self, new_margin): |
|
|
self.teacher_thres_skip = new_margin * self.teacher_eps_skip |
|
|
|
|
|
def set_teacher_thres_equal(self, new_margin): |
|
|
self.teacher_thres_equal = new_margin * self.teacher_eps_equal |
|
|
|
|
|
def construct_ensemble(self): |
|
|
for i in range(self.de): |
|
|
if not self.image_reward: |
|
|
model = nn.Sequential(*gen_net(in_size=self.ds+self.da, |
|
|
out_size=1, H=self.reward_model_H, n_layers=self.reward_model_layers, |
|
|
activation=self.activation)).float().to(device) |
|
|
else: |
|
|
if not self.resnet: |
|
|
model = gen_image_net(self.image_height, self.image_width, self.conv_kernel_sizes, self.conv_n_channels, self.conv_strides).float().to(device) |
|
|
else: |
|
|
model = gen_image_net2().float().to(device) |
|
|
|
|
|
self.ensemble.append(model) |
|
|
self.paramlst.extend(model.parameters()) |
|
|
|
|
|
self.opt = torch.optim.Adam(self.paramlst, lr = self.lr) |
|
|
|
|
|
    def add_data(self, obs, act, rew, done, img=None):
        """Append one transition to the per-trajectory data lists.

        `self.inputs`/`self.targets` (and `self.img_inputs` when `img` is
        given) are lists with one array per trajectory; the current
        trajectory is the last entry. `done` closes the current trajectory
        and starts a new one, evicting the oldest trajectory (FIFO) once
        more than `self.max_size` are stored.
        """
        sa_t = np.concatenate([obs, act], axis=-1)
        r_t = rew

        # flatten to (1, ds+da) / (1, 1) rows so trajectories can be built
        # up by concatenation
        flat_input = sa_t.reshape(1, self.da+self.ds)
        r_t = np.array(r_t)
        flat_target = r_t.reshape(1, 1)
        if img is not None:
            flat_img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])

        init_data = len(self.inputs) == 0
        if init_data:
            # very first transition ever: start the first trajectory
            self.inputs.append(flat_input)
            self.targets.append(flat_target)
            if img is not None:
                self.img_inputs.append(flat_img)
        elif done:
            if 'Cloth' not in self.env_name:
                # close the current trajectory with this final transition
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
                if img is not None:
                    self.img_inputs[-1] = np.concatenate([self.img_inputs[-1], flat_img], axis=0)

                # FIFO eviction of the oldest trajectory
                if len(self.inputs) > self.max_size:
                    self.inputs = self.inputs[1:]
                    self.targets = self.targets[1:]
                    if img is not None:
                        self.img_inputs = self.img_inputs[1:]
                # open an empty slot for the next trajectory
                self.inputs.append([])
                self.targets.append([])
                if img is not None:
                    self.img_inputs.append([])
            else:
                # Cloth envs: the terminal transition starts a fresh
                # single-element trajectory instead of extending the last one
                self.inputs.append([flat_input])
                self.targets.append([flat_target])
                if img is not None:
                    self.img_inputs.append([flat_img])

                # FIFO eviction of the oldest trajectory
                if len(self.inputs) > self.max_size:
                    self.inputs = self.inputs[1:]
                    self.targets = self.targets[1:]
                    if img is not None:
                        self.img_inputs = self.img_inputs[1:]
        else:
            if len(self.inputs[-1]) == 0:
                # first transition after a `done`: replace the empty slot
                self.inputs[-1] = flat_input
                self.targets[-1] = flat_target
                if img is not None:
                    self.img_inputs[-1] = flat_img
            else:
                # mid-trajectory: extend the current trajectory arrays
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
                if img is not None:
                    self.img_inputs[-1] = np.concatenate([self.img_inputs[-1], flat_img], axis=0)
|
|
|
|
|
def add_data_batch(self, obses, rewards): |
|
|
num_env = obses.shape[0] |
|
|
for index in range(num_env): |
|
|
self.inputs.append(obses[index]) |
|
|
self.targets.append(rewards[index]) |
|
|
|
|
|
def get_rank_probability(self, x_1, x_2): |
|
|
|
|
|
probs = [] |
|
|
for member in range(self.de): |
|
|
probs.append(self.p_hat_member(x_1, x_2, member=member).cpu().numpy()) |
|
|
probs = np.array(probs) |
|
|
|
|
|
return np.mean(probs, axis=0), np.std(probs, axis=0) |
|
|
|
|
|
def get_entropy(self, x_1, x_2): |
|
|
|
|
|
probs = [] |
|
|
for member in range(self.de): |
|
|
probs.append(self.p_hat_entropy(x_1, x_2, member=member).cpu().numpy()) |
|
|
probs = np.array(probs) |
|
|
return np.mean(probs, axis=0), np.std(probs, axis=0) |
|
|
|
|
|
def p_hat_member(self, x_1, x_2, member=-1): |
|
|
|
|
|
with torch.no_grad(): |
|
|
r_hat1 = self.r_hat_member(x_1, member=member) |
|
|
r_hat2 = self.r_hat_member(x_2, member=member) |
|
|
r_hat1 = r_hat1.sum(axis=1) |
|
|
r_hat2 = r_hat2.sum(axis=1) |
|
|
r_hat = torch.cat([r_hat1, r_hat2], axis=-1) |
|
|
|
|
|
|
|
|
return F.softmax(r_hat, dim=-1)[:,0] |
|
|
|
|
|
def p_hat_entropy(self, x_1, x_2, member=-1): |
|
|
|
|
|
with torch.no_grad(): |
|
|
r_hat1 = self.r_hat_member(x_1, member=member) |
|
|
r_hat2 = self.r_hat_member(x_2, member=member) |
|
|
r_hat1 = r_hat1.sum(axis=1) |
|
|
r_hat2 = r_hat2.sum(axis=1) |
|
|
r_hat = torch.cat([r_hat1, r_hat2], axis=-1) |
|
|
|
|
|
ent = F.softmax(r_hat, dim=-1) * F.log_softmax(r_hat, dim=-1) |
|
|
ent = ent.sum(axis=-1).abs() |
|
|
return ent |
|
|
|
|
|
def r_hat_member(self, x, member=-1): |
|
|
|
|
|
return self.ensemble[member](torch.from_numpy(x).float().to(device)) |
|
|
|
|
|
def r_hat(self, x): |
|
|
|
|
|
|
|
|
r_hats = [] |
|
|
for member in range(self.de): |
|
|
r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy()) |
|
|
r_hats = np.array(r_hats) |
|
|
return np.mean(r_hats) |
|
|
|
|
|
def r_hat_batch(self, x): |
|
|
|
|
|
|
|
|
r_hats = [] |
|
|
for member in range(self.de): |
|
|
r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy()) |
|
|
r_hats = np.array(r_hats) |
|
|
|
|
|
return np.mean(r_hats, axis=0) |
|
|
|
|
|
def save(self, model_dir, step): |
|
|
for member in range(self.de): |
|
|
torch.save( |
|
|
self.ensemble[member].state_dict(), '%s/reward_model_%s_%s.pt' % (model_dir, step, member) |
|
|
) |
|
|
|
|
|
def load(self, model_dir, step): |
|
|
file_dir = os.path.dirname(os.path.realpath(__file__)) |
|
|
model_dir = os.path.join(file_dir, model_dir) |
|
|
for member in range(self.de): |
|
|
self.ensemble[member].load_state_dict( |
|
|
torch.load('%s/reward_model_%s_%s.pt' % (model_dir, step, member)) |
|
|
) |
|
|
|
|
|
def get_train_acc(self): |
|
|
ensemble_acc = np.array([0 for _ in range(self.de)]) |
|
|
max_len = self.capacity if self.buffer_full else self.buffer_index |
|
|
total_batch_index = np.random.permutation(max_len) |
|
|
batch_size = 256 |
|
|
num_epochs = int(np.ceil(max_len/batch_size)) |
|
|
|
|
|
total = 0 |
|
|
for epoch in range(num_epochs): |
|
|
last_index = (epoch+1)*batch_size |
|
|
if (epoch+1)*batch_size > max_len: |
|
|
last_index = max_len |
|
|
|
|
|
sa_t_1 = self.buffer_seg1[epoch*batch_size:last_index] |
|
|
sa_t_2 = self.buffer_seg2[epoch*batch_size:last_index] |
|
|
labels = self.buffer_label[epoch*batch_size:last_index] |
|
|
labels = torch.from_numpy(labels.flatten()).long().to(device) |
|
|
total += labels.size(0) |
|
|
for member in range(self.de): |
|
|
|
|
|
r_hat1 = self.r_hat_member(sa_t_1, member=member) |
|
|
r_hat2 = self.r_hat_member(sa_t_2, member=member) |
|
|
r_hat1 = r_hat1.sum(axis=1) |
|
|
r_hat2 = r_hat2.sum(axis=1) |
|
|
r_hat = torch.cat([r_hat1, r_hat2], axis=-1) |
|
|
_, predicted = torch.max(r_hat.data, 1) |
|
|
correct = (predicted == labels).sum().item() |
|
|
ensemble_acc[member] += correct |
|
|
|
|
|
ensemble_acc = ensemble_acc / total |
|
|
return np.mean(ensemble_acc) |
|
|
|
|
|
    def get_queries(self, mb_size=20):
        """Sample `mb_size` pairs of trajectory segments (and frames) for labelling.

        Trajectories are drawn uniformly with replacement; within each
        trajectory a random window of length `self.size_segment` is taken
        (except Cloth envs, which always use the window starting at 0).
        Returns (sa_t_1, sa_t_2, r_t_1, r_t_2) and, when VLM labelling or
        image rewards are enabled, the corresponding image strips
        (img_t_1, img_t_2).
        """
        len_traj, max_len = len(self.inputs[0]), len(self.inputs)

        # drop the trailing, still-incomplete trajectory
        if len(self.inputs[-1]) < len_traj:
            max_len = max_len - 1

        train_inputs = np.array(self.inputs[:max_len])
        train_targets = np.array(self.targets[:max_len])
        if self.vlm_label or self.image_reward:
            train_images = np.array(self.img_inputs[:max_len])
            if 'Cloth' in self.env_name:
                # Cloth image arrays carry an extra singleton axis -- drop it
                train_images = train_images.squeeze(1)

        # pick trajectories (with replacement) for each side of the pair
        batch_index_2 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_2 = train_inputs[batch_index_2]
        r_t_2 = train_targets[batch_index_2]
        if self.vlm_label or self.image_reward:
            img_t_2 = train_images[batch_index_2]

        batch_index_1 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_1 = train_inputs[batch_index_1]
        r_t_1 = train_targets[batch_index_1]
        if self.vlm_label or self.image_reward:
            img_t_1 = train_images[batch_index_1]

        # flatten (traj, step, feat) -> (traj*step, feat) so windows can be
        # gathered with np.take on flat time indices below
        sa_t_1 = sa_t_1.reshape(-1, sa_t_1.shape[-1])
        r_t_1 = r_t_1.reshape(-1, r_t_1.shape[-1])
        sa_t_2 = sa_t_2.reshape(-1, sa_t_2.shape[-1])
        r_t_2 = r_t_2.reshape(-1, r_t_2.shape[-1])
        if self.vlm_label or self.image_reward:
            img_t_1 = img_t_1.reshape(-1, img_t_1.shape[2], img_t_1.shape[3], img_t_1.shape[4])
            img_t_2 = img_t_2.reshape(-1, img_t_2.shape[2], img_t_2.shape[3], img_t_2.shape[4])

        # base window [0, size_segment) for each selected trajectory, offset
        # by that trajectory's position in the flattened array
        time_index = np.array([list(range(i*len_traj, i*len_traj+self.size_segment)) for i in range(mb_size)])
        if 'Cloth' not in self.env_name:
            # shift each window to a random start inside its trajectory
            random_idx_2 = np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
            time_index_2 = time_index + random_idx_2
            random_idx_1 = np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
            time_index_1 = time_index + random_idx_1
        else:
            time_index_2 = time_index
            time_index_1 = time_index
        if self.vlm_label or self.image_reward:
            if self.vlm_label == 1 or self.image_reward:
                # one frame per query: the last frame of the window
                image_time_index = np.array([[i*len_traj+self.size_segment - 1] for i in range(mb_size)])
            else:
                # vlm_label > 1: evenly spaced frames ending at the last step
                interval = self.size_segment // self.vlm_label
                image_time_index = np.array([[i * len_traj + self.size_segment - 1 - j * interval for j in range(self.vlm_label - 1, -1, -1)] for i in range(mb_size)])
                image_time_index = np.maximum(image_time_index, 0)

            if 'Cloth' not in self.env_name:
                image_time_index_2 = image_time_index + random_idx_2
                image_time_index_1 = image_time_index + random_idx_1
            else:
                image_time_index_2 = image_time_index
                image_time_index_1 = image_time_index

        # gather the windows from the flattened arrays
        sa_t_1 = np.take(sa_t_1, time_index_1, axis=0)
        r_t_1 = np.take(r_t_1, time_index_1, axis=0)
        sa_t_2 = np.take(sa_t_2, time_index_2, axis=0)
        r_t_2 = np.take(r_t_2, time_index_2, axis=0)
        if self.vlm_label or self.image_reward:
            img_t_1 = np.take(img_t_1, image_time_index_1, axis=0)
            img_t_2 = np.take(img_t_2, image_time_index_2, axis=0)

            batch_size, horizon, image_height, image_width, _ = img_t_1.shape

            # tile the `horizon` frames side by side into one wide image strip
            transposed_images = np.transpose(img_t_1, (0, 2, 1, 3, 4))
            img_t_1 = transposed_images.reshape(batch_size, image_height, horizon * image_width, 3)
            transposed_images = np.transpose(img_t_2, (0, 2, 1, 3, 4))
            img_t_2 = transposed_images.reshape(batch_size, image_height, horizon * image_width, 3)

        if not self.vlm_label and not self.image_reward:
            return sa_t_1, sa_t_2, r_t_1, r_t_2
        else:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2
|
|
|
|
|
    def put_queries(self, sa_t_1, sa_t_2, labels):
        """Copy a batch of labelled query pairs into the circular preference buffer.

        `sa_t_1`/`sa_t_2` are the compared segments (or raw images when
        image_reward is on); `labels` is the (N, 1) preference array.
        """
        total_sample = sa_t_1.shape[0]
        next_index = self.buffer_index + total_sample

        if next_index >= self.capacity:
            # wrap-around: fill to the end of the buffer, then restart at 0
            self.buffer_full = True
            maximum_index = self.capacity - self.buffer_index
            np.copyto(self.buffer_seg1[self.buffer_index:self.capacity], sa_t_1[:maximum_index])
            np.copyto(self.buffer_seg2[self.buffer_index:self.capacity], sa_t_2[:maximum_index])
            np.copyto(self.buffer_label[self.buffer_index:self.capacity], labels[:maximum_index])

            remain = total_sample - (maximum_index)
            if remain > 0:
                # overflow goes back to the start of the buffer
                np.copyto(self.buffer_seg1[0:remain], sa_t_1[maximum_index:])
                np.copyto(self.buffer_seg2[0:remain], sa_t_2[maximum_index:])
                np.copyto(self.buffer_label[0:remain], labels[maximum_index:])

            self.buffer_index = remain
        else:
            if self.image_reward:
                # images arrive as (N, H, W, 3); insert the length-1 segment
                # axis expected by the (capacity, 1, H, W, 3) buffer.
                # NOTE(review): the wrap-around branch above does not apply
                # this reshape -- it would fail on image buffers once the
                # buffer wraps; confirm against callers.
                sa_t_1 = sa_t_1.reshape(sa_t_1.shape[0], 1, sa_t_1.shape[1], sa_t_1.shape[2], sa_t_1.shape[3])
                sa_t_2 = sa_t_2.reshape(sa_t_2.shape[0], 1, sa_t_2.shape[1], sa_t_2.shape[2], sa_t_2.shape[3])
            np.copyto(self.buffer_seg1[self.buffer_index:next_index], sa_t_1)
            np.copyto(self.buffer_seg2[self.buffer_index:next_index], sa_t_2)
            np.copyto(self.buffer_label[self.buffer_index:next_index], labels)
            self.buffer_index = next_index
|
|
|
|
|
    def get_label(self, sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1=None, img_t_2=None):
        """Label query pairs with a scripted teacher and, optionally, a VLM.

        The scripted teacher prefers the segment with the larger (optionally
        gamma-discounted) return, can skip low-reward queries, mark
        near-ties as ambiguous (-1), flip labels with probability
        `teacher_eps_mistake`, and sample stochastically when
        `teacher_beta > 0`. When `self.vlm_label` is set, the image pairs
        are additionally sent to the configured VLM and its labels are
        returned alongside the scripted ones.
        """
        sum_r_t_1 = np.sum(r_t_1, axis=1)
        sum_r_t_2 = np.sum(r_t_2, axis=1)

        # skip queries whose best segment return is below the skip threshold
        if self.teacher_thres_skip > 0:
            max_r_t = np.maximum(sum_r_t_1, sum_r_t_2)
            max_index = (max_r_t > self.teacher_thres_skip).reshape(-1)
            if sum(max_index) == 0:
                return None, None, None, None, []

            sa_t_1 = sa_t_1[max_index]
            sa_t_2 = sa_t_2[max_index]
            r_t_1 = r_t_1[max_index]
            r_t_2 = r_t_2[max_index]
            sum_r_t_1 = np.sum(r_t_1, axis=1)
            sum_r_t_2 = np.sum(r_t_2, axis=1)

        # queries whose returns are within teacher_thres_equal are "equally preferable"
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) < self.teacher_thres_equal).reshape(-1)

        # gamma-discount earlier steps: step t ends up weighted by
        # teacher_gamma**(seg_size-1-t)
        seg_size = r_t_1.shape[1]
        temp_r_t_1 = r_t_1.copy()
        temp_r_t_2 = r_t_2.copy()
        for index in range(seg_size-1):
            temp_r_t_1[:,:index+1] *= self.teacher_gamma
            temp_r_t_2[:,:index+1] *= self.teacher_gamma
        sum_r_t_1 = np.sum(temp_r_t_1, axis=1)
        sum_r_t_2 = np.sum(temp_r_t_2, axis=1)

        # label 1 means segment 2 is preferred, 0 means segment 1
        rational_labels = 1*(sum_r_t_1 < sum_r_t_2)
        if self.teacher_beta > 0:
            # stochastic (Boltzmann) teacher: sample the label from a softmax
            # over the two discounted returns scaled by beta
            r_hat = torch.cat([torch.Tensor(sum_r_t_1),
                               torch.Tensor(sum_r_t_2)], axis=-1)
            r_hat = r_hat*self.teacher_beta
            ent = F.softmax(r_hat, dim=-1)[:, 1]
            labels = torch.bernoulli(ent).int().numpy().reshape(-1, 1)
        else:
            labels = rational_labels

        # flip labels with probability teacher_eps_mistake
        len_labels = labels.shape[0]
        rand_num = np.random.rand(len_labels)
        noise_index = rand_num <= self.teacher_eps_mistake
        labels[noise_index] = 1 - labels[noise_index]

        # ambiguous queries get label -1
        labels[margin_index] = -1

        if self.vlm_label:
            ts = time.time()
            time_string = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M-%S')

            gpt_two_image_paths = []
            combined_images_list = []
            useful_indices = []  # 0 when the two images are (near-)identical

            # images for this query round are saved under
            # data/gpt_query_image/<env>/<timestamp>/ next to this file
            file_path = os.path.abspath(__file__)
            dir_path = os.path.dirname(file_path)
            save_path = "{}/data/gpt_query_image/{}/{}".format(dir_path, self.env_name, time_string)
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                # side-by-side composite kept for the saved label set
                combined_image = np.concatenate([img1, img2], axis=1)
                combined_images_list.append(combined_image)
                combined_image = Image.fromarray(combined_image)

                first_image_save_path = os.path.join(save_path, "first_{:06}.png".format(idx))
                second_image_save_path = os.path.join(save_path, "second_{:06}.png".format(idx))
                Image.fromarray(img1).save(first_image_save_path)
                Image.fromarray(img2).save(second_image_save_path)
                gpt_two_image_paths.append([first_image_save_path, second_image_save_path])

                # mark visually indistinguishable pairs as not useful
                diff = np.linalg.norm(img1 - img2)
                if diff < 1e-3:
                    useful_indices.append(0)
                else:
                    useful_indices.append(1)

            if self.vlm == 'gpt4v_two_image':
                # GPT-4V is queried from the saved image files
                from vlms.gpt4_infer import gpt4v_infer_2
                vlm_labels = []
                for idx, (img_path_1, img_path_2) in enumerate(gpt_two_image_paths):
                    print("querying vlm {}/{}".format(idx, len(gpt_two_image_paths)))
                    query_prompt = gpt_free_query_env_prompts[self.env_name]
                    summary_prompt = gpt_summary_env_prompts[self.env_name]
                    res = gpt4v_infer_2(query_prompt, summary_prompt, img_path_1, img_path_2)
                    try:
                        label_res = int(res)
                    except:
                        # unparseable response -> discard this query later
                        label_res = -1

                    vlm_labels.append(label_res)
                    time.sleep(0.1)  # simple rate limiting between API calls
            elif self.vlm == 'gemini_single_prompt':
                vlm_labels = []
                for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                    res = gemini_query_1([
                        gemini_free_query_prompt1,
                        Image.fromarray(img1),
                        gemini_free_query_prompt2,
                        Image.fromarray(img2),
                        gemini_single_query_env_prompts[self.env_name],
                    ])
                    # substring matching on the raw response; note "-1" is
                    # checked first because "1" is a substring of "-1"
                    try:
                        if "-1" in res:
                            res = -1
                        elif "0" in res:
                            res = 0
                        elif "1" in res:
                            res = 1
                        else:
                            res = -1
                    except:
                        res = -1
                    vlm_labels.append(res)
            elif self.vlm == "gemini_free_form":
                # two-stage query: free-form analysis then a summary answer
                vlm_labels = []
                for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                    res = gemini_query_2(
                        [
                            gemini_free_query_prompt1,
                            Image.fromarray(img1),
                            gemini_free_query_prompt2,
                            Image.fromarray(img2),
                            gemini_free_query_env_prompts[self.env_name]
                        ],
                        gemini_summary_env_prompts[self.env_name]
                    )
                    try:
                        res = int(res)
                        if res not in [0, 1, -1]:
                            res = -1
                    except:
                        res = -1
                    vlm_labels.append(res)

            # keep only queries the VLM answered (label != -1) and whose
            # images actually differ
            vlm_labels = np.array(vlm_labels).reshape(-1, 1)
            good_idx = (vlm_labels != -1).flatten()
            useful_indices = (np.array(useful_indices) == 1).flatten()
            good_idx = np.logical_and(good_idx, useful_indices)

            sa_t_1 = sa_t_1[good_idx]
            sa_t_2 = sa_t_2[good_idx]
            r_t_1 = r_t_1[good_idx]
            r_t_2 = r_t_2[good_idx]
            rational_labels = rational_labels[good_idx]
            vlm_labels = vlm_labels[good_idx]
            combined_images_list = np.array(combined_images_list)[good_idx]
            img_t_1 = img_t_1[good_idx]
            img_t_2 = img_t_2[good_idx]
            if self.flip_vlm_label:
                # debugging option: invert the VLM's preferences
                vlm_labels = 1 - vlm_labels

            # periodically dump the labelled set (always for gpt4v, which is costly)
            if self.train_times % self.save_query_interval == 0 or 'gpt4v' in self.vlm:
                save_path = os.path.join(self.log_dir, "vlm_label_set")
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                with open("{}/{}.pkl".format(save_path, time_string), "wb") as f:
                    pkl.dump([combined_images_list, rational_labels, vlm_labels, sa_t_1, sa_t_2, r_t_1, r_t_2], f, protocol=pkl.HIGHEST_PROTOCOL)

            # agreement between VLM labels and the scripted-teacher labels
            # NOTE(review): the prints below are triplicated in the original
            acc = 0
            if len(vlm_labels) > 0:
                acc = np.sum(vlm_labels == rational_labels) / len(vlm_labels)
                print("vlm label acc: {}".format(acc))
                print("vlm label acc: {}".format(acc))
                print("vlm label acc: {}".format(acc))
            else:
                print("no vlm label")
                print("no vlm label")
                print("no vlm label")

            self.vlm_label_acc = acc
            if not self.image_reward:
                return sa_t_1, sa_t_2, r_t_1, r_t_2, labels, vlm_labels
            else:
                return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels, vlm_labels

        # no VLM: return only the scripted-teacher labels
        if not self.image_reward:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, labels
        else:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels
|
|
|
|
|
    def kcenter_sampling(self):
        """Select queries by k-center coverage over the state features.

        Over-samples `mb_size * large_batch` candidate pairs, then keeps the
        `mb_size` pairs that best cover the space relative to the queries
        already stored in the buffer. Returns the number of labels added.
        NOTE(review): unpacks a 5-tuple from get_label, so this path
        presumably assumes vlm_label and image_reward are off -- confirm.
        """
        num_init = self.mb_size*self.large_batch
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)

        # distance metric uses only the state part (first self.ds dims),
        # with both segments of a pair flattened and concatenated
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]
        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init, -1),
                                  temp_sa_t_2.reshape(num_init, -1)], axis=1)

        max_len = self.capacity if self.buffer_full else self.buffer_index

        # already-stored queries form the coverage set
        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)

        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)

        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]

        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)
|
|
|
|
|
    def kcenter_disagree_sampling(self):
        """Select queries by ensemble disagreement, then k-center coverage.

        Over-samples candidates, keeps the half with the highest ensemble
        disagreement (std of the preference probability), then applies
        k-center selection among those. Returns the number of labels added.
        NOTE(review): unpacks a 5-tuple from get_label, so this path
        presumably assumes vlm_label and image_reward are off -- confirm.
        """
        num_init = self.mb_size*self.large_batch
        num_init_half = int(num_init*0.5)

        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)

        # keep the half with the largest disagreement across members
        _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
        top_k_index = (-disagree).argsort()[:num_init_half]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]

        # k-center over the state part of the remaining candidates
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]

        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init_half, -1),
                                  temp_sa_t_2.reshape(num_init_half, -1)], axis=1)

        max_len = self.capacity if self.buffer_full else self.buffer_index

        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)

        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)

        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]

        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)
|
|
|
|
|
    def kcenter_entropy_sampling(self):
        """Select queries by preference entropy, then k-center coverage.

        Over-samples candidates, keeps the half with the highest mean
        preference entropy across the ensemble, then applies k-center
        selection among those. Returns the number of labels added.
        NOTE(review): unpacks a 5-tuple from get_label, so this path
        presumably assumes vlm_label and image_reward are off -- confirm.
        """
        num_init = self.mb_size*self.large_batch
        num_init_half = int(num_init*0.5)

        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)

        # keep the half with the largest mean entropy (most uncertain pairs)
        entropy, _ = self.get_entropy(sa_t_1, sa_t_2)
        top_k_index = (-entropy).argsort()[:num_init_half]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]

        # k-center over the state part of the remaining candidates
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]

        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init_half, -1),
                                  temp_sa_t_2.reshape(num_init_half, -1)], axis=1)

        max_len = self.capacity if self.buffer_full else self.buffer_index

        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)

        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)

        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]

        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)
|
|
|
|
|
def uniform_sampling(self):
    """Sample query pairs uniformly, label them, and store the results.

    Dispatches on two flags:
      * ``self.vlm_label`` — labels come from a VLM (optionally replayed
        from a cache on disk) instead of the scripted/ground-truth teacher.
      * ``self.image_reward`` — segments are stored and trained on as
        images rather than state-action vectors.

    Returns the number of labels actually obtained (0 is valid; e.g. the
    VLM may abstain, or the label cache may be exhausted).
    """
    if not self.vlm_label:
        # Scripted/ground-truth teacher path.
        if not self.image_reward:
            sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
                mb_size=self.mb_size)
            sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
                sa_t_1, sa_t_2, r_t_1, r_t_2)
        else:
            # Image-based reward: queries also carry rendered frames.
            sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2 = self.get_queries(
                mb_size=self.mb_size)
            sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels = self.get_label(
                sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
    else:
        # VLM-teacher path.
        if self.cached_label_path is None:
            # Query the VLM live; also receive ground-truth labels for logging.
            sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2 = self.get_queries(
                mb_size=self.mb_size)
            if not self.image_reward:
                sa_t_1, sa_t_2, r_t_1, r_t_2, gt_labels, vlm_labels = self.get_label(
                    sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
            else:
                sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, gt_labels, vlm_labels = self.get_label(
                    sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
        else:
            # Replay previously collected VLM labels from disk.
            if self.read_cache_idx < len(self.all_cached_labels):
                combined_images_list, sa_t_1, sa_t_2, r_t_1, r_t_2, gt_labels, vlm_labels = self.get_label_from_cached_states()
                if self.image_reward:
                    # Cached frames are stored side by side: left half is
                    # segment 1, right half is segment 2.
                    num, height, width, _ = combined_images_list.shape
                    img_t_1 = combined_images_list[:, :, :width//2, :]
                    img_t_2 = combined_images_list[:, :, width//2:, :]
                    if 'Rope' not in self.env_name and \
                        'Water' not in self.env_name:
                        # Resize cached frames to the reward model's input size.
                        # NOTE(review): cv2.resize takes dsize as (width, height);
                        # passing (image_height, image_width) only matches the
                        # (image_height, image_width) target buffers when the two
                        # are equal — confirm for non-square renders.
                        resized_img_t_1 = np.zeros((num, self.image_height, self.image_width, 3), dtype=np.uint8)
                        resized_img_t_2 = np.zeros((num, self.image_height, self.image_width, 3), dtype=np.uint8)
                        for idx in range(len(img_t_1)):
                            resized_img_t_1[idx] = cv2.resize(img_t_1[idx], (self.image_height, self.image_width))
                            resized_img_t_2[idx] = cv2.resize(img_t_2[idx], (self.image_height, self.image_width))
                        img_t_1 = resized_img_t_1
                        img_t_2 = resized_img_t_2
            else:
                # Cache exhausted: nothing to label this round.
                vlm_labels = []
        # On the VLM path the stored labels are the VLM's answers.
        labels = vlm_labels

    if len(labels) > 0:
        if not self.image_reward:
            self.put_queries(sa_t_1, sa_t_2, labels)
        else:
            # Downsample frames spatially before storing to save memory.
            self.put_queries(img_t_1[:, ::self.resize_factor, ::self.resize_factor, :], img_t_2[:, ::self.resize_factor, ::self.resize_factor, :], labels)

    return len(labels)
|
|
|
|
|
def get_label_from_cached_states(self):
    """Load the next batch of pre-computed query labels from disk.

    Each cache file (pickled by an earlier run) holds a 7-tuple:
    ``(combined_images_list, rational_labels, vlm_labels,
       sa_t_1, sa_t_2, r_t_1, r_t_2)``.

    Returns:
        ``(combined_images_list, sa_t_1, sa_t_2, r_t_1, r_t_2,
           rational_labels, vlm_labels)``. When the cache is exhausted,
        returns ``(None, None, None, None, None, None, [])`` — same
        arity as the normal return so callers can always unpack.

    Side effect: advances ``self.read_cache_idx`` by one on success.
    """
    if self.read_cache_idx >= len(self.all_cached_labels):
        # Fixed: previously returned only 6 values, which crashed any
        # 7-way unpack at the call site.
        return None, None, None, None, None, None, []
    with open(self.all_cached_labels[self.read_cache_idx], 'rb') as f:
        data = pkl.load(f)
    combined_images_list, rational_labels, vlm_labels, sa_t_1, sa_t_2, r_t_1, r_t_2 = data
    self.read_cache_idx += 1
    return combined_images_list, sa_t_1, sa_t_2, r_t_1, r_t_2, rational_labels, vlm_labels
|
|
|
|
|
def disagreement_sampling(self):
    """Query the pairs the reward ensemble disagrees on the most.

    Draws a candidate pool ``large_batch`` times larger than needed,
    ranks it by ensemble disagreement, labels the top ``mb_size`` pairs
    and stores them. Returns the number of labels collected.
    """
    # Oversample candidate segment pairs.
    sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
        mb_size=self.mb_size * self.large_batch)

    # Keep the mb_size candidates with the largest disagreement.
    _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
    keep = (-disagree).argsort()[:self.mb_size]
    sa_t_1, r_t_1 = sa_t_1[keep], r_t_1[keep]
    sa_t_2, r_t_2 = sa_t_2[keep], r_t_2[keep]

    # Label the survivors and store whatever comes back.
    sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
        sa_t_1, sa_t_2, r_t_1, r_t_2)
    if len(labels) > 0:
        self.put_queries(sa_t_1, sa_t_2, labels)
    return len(labels)
|
|
|
|
|
def entropy_sampling(self):
    """Query the pairs with the highest ensemble predictive entropy.

    Draws a candidate pool ``large_batch`` times larger than needed,
    ranks it by entropy, labels the top ``mb_size`` pairs and stores
    them. Returns the number of labels collected.
    """
    # Oversample candidate segment pairs.
    sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
        mb_size=self.mb_size * self.large_batch)

    # Keep the mb_size most uncertain candidates.
    entropy, _ = self.get_entropy(sa_t_1, sa_t_2)
    keep = (-entropy).argsort()[:self.mb_size]
    sa_t_1, r_t_1 = sa_t_1[keep], r_t_1[keep]
    sa_t_2, r_t_2 = sa_t_2[keep], r_t_2[keep]

    # Label the survivors and store whatever comes back.
    sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
        sa_t_1, sa_t_2, r_t_1, r_t_2)
    if len(labels) > 0:
        self.put_queries(sa_t_1, sa_t_2, labels)
    return len(labels)
|
|
|
|
|
def train_reward(self):
    """Run one pass of cross-entropy reward training over the buffer.

    Each ensemble member sees the whole labelled buffer once, in its own
    random order. Losses from all members in a mini-batch are summed and
    optimized in a single backward/step, so ``self.opt`` must cover all
    member parameters.

    Returns:
        np.ndarray of shape (self.de,): per-member label accuracy over
        the full buffer.
    """
    self.train_times += 1

    # Per-member loss history and per-member correct-prediction counts.
    ensemble_losses = [[] for _ in range(self.de)]
    ensemble_acc = np.array([0 for _ in range(self.de)])

    # Number of valid entries in the circular preference buffer.
    max_len = self.capacity if self.buffer_full else self.buffer_index
    # Independent shuffle of the buffer for each ensemble member.
    total_batch_index = []
    for _ in range(self.de):
        total_batch_index.append(np.random.permutation(max_len))

    # "epochs" here are mini-batches of one pass over the buffer.
    num_epochs = int(np.ceil(max_len/self.train_batch_size))
    total = 0  # number of labels seen (counted once, via member 0)

    for epoch in range(num_epochs):
        self.opt.zero_grad()
        loss = 0.0

        # Clamp the batch end to the buffer size for the last batch.
        last_index = (epoch+1)*self.train_batch_size
        if last_index > max_len:
            last_index = max_len

        for member in range(self.de):

            # Each member trains on its own shuffled slice.
            idxs = total_batch_index[member][epoch*self.train_batch_size:last_index]
            sa_t_1 = self.buffer_seg1[idxs]
            sa_t_2 = self.buffer_seg2[idxs]
            labels = self.buffer_label[idxs]
            labels = torch.from_numpy(labels.flatten()).long().to(device)

            # Count each label once, not once per member.
            if member == 0:
                total += labels.size(0)

            if self.image_reward:
                # NHWC -> NCHW for the CNN, with a leading segment axis
                # of length 1 that is squeezed below.
                sa_t_1 = np.transpose(sa_t_1, (0, 1, 4, 2, 3))
                sa_t_2 = np.transpose(sa_t_2, (0, 1, 4, 2, 3))

                # uint8 pixels -> float in [0, 1].
                sa_t_1 = sa_t_1.astype(np.float32) / 255.0
                sa_t_2 = sa_t_2.astype(np.float32) / 255.0
                sa_t_1 = sa_t_1.squeeze(1)
                sa_t_2 = sa_t_2.squeeze(1)

            # Predicted return of each segment under this member.
            r_hat1 = self.r_hat_member(sa_t_1, member=member)
            r_hat2 = self.r_hat_member(sa_t_2, member=member)
            if not self.image_reward:
                # State segments yield per-step rewards; sum over time.
                r_hat1 = r_hat1.sum(axis=1)
                r_hat2 = r_hat2.sum(axis=1)
            # Two-way logits: which segment is preferred.
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

            # Bradley-Terry style preference loss.
            curr_loss = self.CEloss(r_hat, labels)
            loss += curr_loss
            ensemble_losses[member].append(curr_loss.item())

            # Accuracy bookkeeping (no grad needed; uses .data).
            _, predicted = torch.max(r_hat.data, 1)
            correct = (predicted == labels).sum().item()
            ensemble_acc[member] += correct

        # One optimizer step for the summed loss of all members.
        loss.backward()
        self.opt.step()

    ensemble_acc = ensemble_acc / total

    torch.cuda.empty_cache()

    return ensemble_acc
|
|
|
|
|
def train_soft_reward(self):
    """Reward training with soft (label-smoothed) cross entropy.

    Same loop structure as ``train_reward`` but targets are soft
    distributions built from ``self.label_target`` / ``self.label_margin``,
    and a label of ``-1`` (teacher found the pair equally preferable) is
    trained toward a uniform (0.5, 0.5) target.

    NOTE(review): unlike ``train_reward`` this path has no image
    preprocessing branch — presumably only used with state-based
    segments; confirm against callers.

    Returns:
        np.ndarray of shape (self.de,): per-member label accuracy over
        the full buffer.
    """
    ensemble_losses = [[] for _ in range(self.de)]
    ensemble_acc = np.array([0 for _ in range(self.de)])

    # Number of valid entries in the circular preference buffer.
    max_len = self.capacity if self.buffer_full else self.buffer_index
    # Independent shuffle of the buffer for each ensemble member.
    total_batch_index = []
    for _ in range(self.de):
        total_batch_index.append(np.random.permutation(max_len))

    num_epochs = int(np.ceil(max_len/self.train_batch_size))
    list_debug_loss1, list_debug_loss2 = [], []
    total = 0  # number of labels seen (counted once, via member 0)

    for epoch in range(num_epochs):
        self.opt.zero_grad()
        loss = 0.0

        # Clamp the batch end to the buffer size for the last batch.
        last_index = (epoch+1)*self.train_batch_size
        if last_index > max_len:
            last_index = max_len

        for member in range(self.de):

            # Each member trains on its own shuffled slice.
            idxs = total_batch_index[member][epoch*self.train_batch_size:last_index]
            sa_t_1 = self.buffer_seg1[idxs]
            sa_t_2 = self.buffer_seg2[idxs]
            labels = self.buffer_label[idxs]
            labels = torch.from_numpy(labels.flatten()).long().to(device)

            # Count each label once, not once per member.
            if member == 0:
                total += labels.size(0)

            # Predicted return of each segment; sum per-step rewards
            # over time, then form two-way preference logits.
            r_hat1 = self.r_hat_member(sa_t_1, member=member)
            r_hat2 = self.r_hat_member(sa_t_2, member=member)
            r_hat1 = r_hat1.sum(axis=1)
            r_hat2 = r_hat2.sum(axis=1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

            # -1 marks "equally preferable"; temporarily remap to class 0
            # so scatter() below has a valid index. These rows get a
            # uniform target afterwards.
            uniform_index = labels == -1
            labels[uniform_index] = 0
            # Soft target: label_target mass on the chosen class plus a
            # label_margin floor everywhere.
            target_onehot = torch.zeros_like(r_hat).scatter(1, labels.unsqueeze(1), self.label_target)
            target_onehot += self.label_margin
            if sum(uniform_index) > 0:
                target_onehot[uniform_index] = 0.5
            curr_loss = self.softXEnt_loss(r_hat, target_onehot)
            loss += curr_loss
            ensemble_losses[member].append(curr_loss.item())

            # Accuracy bookkeeping. Note: uniform (-1) rows were remapped
            # to class 0 above, so they count as class-0 here.
            _, predicted = torch.max(r_hat.data, 1)
            correct = (predicted == labels).sum().item()
            ensemble_acc[member] += correct

        # One optimizer step for the summed loss of all members.
        loss.backward()
        self.opt.step()

    ensemble_acc = ensemble_acc / total

    return ensemble_acc