# basketball_code / reward_model.py
# (file header: uploaded via huggingface_hub, revision 0c51b93)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import asyncio
from PIL import Image
import datetime
import pickle as pkl
import random
import cv2
from prompt import (
gemini_free_query_env_prompts, gemini_summary_env_prompts,
gemini_free_query_prompt1, gemini_free_query_prompt2,
gemini_single_query_env_prompts,
gpt_free_query_env_prompts, gpt_summary_env_prompts,
)
from vlms.gemini_infer import gemini_query_2, gemini_query_1
from conv_net import CNN, fanin_init
device = 'cuda'
def gen_net(in_size=1, out_size=1, H=128, n_layers=3, activation='tanh'):
    """Build the layer list for an MLP reward head.

    Produces ``n_layers`` hidden blocks of Linear(width, H) + LeakyReLU,
    a final Linear to ``out_size``, and an output activation chosen by
    ``activation`` ('tanh' -> Tanh, 'sig' -> Sigmoid, anything else -> ReLU).
    Returns a plain list of nn.Module objects; the caller wraps it in
    nn.Sequential.
    """
    layers = []
    width = in_size
    for _ in range(n_layers):
        layers += [nn.Linear(width, H), nn.LeakyReLU()]
        width = H
    layers.append(nn.Linear(width, out_size))
    # dispatch table instead of an if/elif chain; ReLU is the fallback
    output_activations = {'tanh': nn.Tanh, 'sig': nn.Sigmoid}
    layers.append(output_activations.get(activation, nn.ReLU)())
    return layers
def gen_image_net(image_height, image_width,
                  conv_kernel_sizes=[5, 3, 3, 3],
                  conv_n_channels=[16, 32, 64, 128],
                  conv_strides=[3, 2, 2, 2]):
    """Build the convolutional image reward network.

    A CNN with the given per-layer kernels/channels/strides, zero padding,
    no hidden FC layers, no batch norm, and a single scalar output.
    """
    num_conv_layers = len(conv_kernel_sizes)
    return CNN(
        kernel_sizes=conv_kernel_sizes,
        n_channels=conv_n_channels,
        strides=conv_strides,
        output_size=1,
        paddings=np.zeros(num_conv_layers, dtype=np.int64),
        input_height=image_height,
        input_width=image_width,
        input_channels=3,
        init_w=1e-3,
        hidden_init=fanin_init,
        hidden_sizes=[],
        batch_norm_conv=False,
        batch_norm_fc=False,
    )
def gen_image_net2():
    """Build a ResNet-18 (BasicBlock, [2, 2, 2, 2]) with a 1-unit output head,
    used as the image reward network when ``resnet`` is requested."""
    from torchvision.models.resnet import BasicBlock, ResNet
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1)
def KCenterGreedy(obs, full_obs, num_new_sample):
    """Greedy k-center selection.

    Iteratively picks ``num_new_sample`` rows of ``obs`` that are farthest
    (in L2) from the growing reference set ``full_obs`` + already-selected
    rows. Returns the list of selected row indices into ``obs``.
    """
    selected_index = []
    # indices of obs rows still available for selection
    current_index = list(range(obs.shape[0]))
    new_obs = obs
    new_full_obs = full_obs
    start_time = time.time()
    for count in range(num_new_sample):
        # distance from each remaining candidate to its nearest reference point
        dist = compute_smallest_dist(new_obs, new_full_obs)
        max_index = torch.argmax(dist)
        max_index = max_index.item()
        if count == 0:
            # on the first pass current_index is the identity, so the raw
            # argmax is already an index into obs
            selected_index.append(max_index)
        else:
            # map position in the shrunken candidate list back to an obs index
            selected_index.append(current_index[max_index])
        # remove the chosen candidate and add it to the reference set
        current_index = current_index[0:max_index] + current_index[max_index+1:]
        new_obs = obs[current_index]
        new_full_obs = np.concatenate([
            full_obs,
            obs[selected_index]],
            axis=0)
    return selected_index
def compute_smallest_dist(obs, full_obs):
    """Distance from each row of ``obs`` to its nearest neighbour in ``full_obs``.

    Both inputs are numpy arrays of shape (N, D) and (M, D). Returns an
    (N, 1) torch tensor of Euclidean distances, computed on ``device`` in
    100-row tiles to bound peak memory.
    """
    obs = torch.from_numpy(obs).float()
    full_obs = torch.from_numpy(full_obs).float()
    batch_size = 100
    with torch.no_grad():
        total_dists = []
        # tile over obs rows ...
        for full_idx in range(len(obs) // batch_size + 1):
            full_start = full_idx * batch_size
            if full_start < len(obs):
                full_end = (full_idx + 1) * batch_size
                dists = []
                # ... and over full_obs rows, so only a 100x100xD slab is live
                for idx in range(len(full_obs) // batch_size + 1):
                    start = idx * batch_size
                    if start < len(full_obs):
                        end = (idx + 1) * batch_size
                        # pairwise L2 distances between the two tiles via broadcasting
                        dist = torch.norm(
                            obs[full_start:full_end, None, :].to(device) - full_obs[None, start:end, :].to(device), dim=-1, p=2
                        )
                        dists.append(dist)
                dists = torch.cat(dists, dim=1)
                # fix: was `torch.torch.min`, which only worked through an
                # accidental module self-reference; call torch.min directly
                small_dists = torch.min(dists, dim=1).values
                total_dists.append(small_dists)
        total_dists = torch.cat(total_dists)
    return total_dists.unsqueeze(1)
class RewardModel:
def __init__(self, ds, da,
ensemble_size=3, lr=3e-4, mb_size = 128, size_segment=1,
max_size=100, activation='tanh', capacity=5e5,
large_batch=1, label_margin=0.0,
teacher_beta=-1, teacher_gamma=1,
teacher_eps_mistake=0,
teacher_eps_skip=0,
teacher_eps_equal=0,
# vlm related params
vlm_label=True,
env_name="CartPole-v1",
vlm="gemini_free_form",
clip_prompt=None,
log_dir=None,
flip_vlm_label=False,
save_query_interval=25,
cached_label_path=None,
# image based reward
reward_model_layers=3,
reward_model_H=256,
image_reward=True,
image_height=128,
image_width=128,
resize_factor=1,
resnet=False,
conv_kernel_sizes=[5, 3, 3 ,3],
conv_n_channels=[16, 32, 64, 128],
conv_strides=[3, 2, 2, 2],
**kwargs
):
# train data is trajectories, must process to sa and s..
self.ds = ds
self.da = da
self.de = ensemble_size
self.lr = lr
self.ensemble = []
self.paramlst = []
self.opt = None
self.model = None
self.max_size = max_size
self.activation = activation
self.size_segment = size_segment
self.capacity = int(capacity)
self.reward_model_layers = reward_model_layers
self.reward_model_H = reward_model_H
self.image_reward = image_reward
self.resnet = resnet
self.conv_kernel_sizes = conv_kernel_sizes
self.conv_n_channels = conv_n_channels
self.conv_strides = conv_strides
if not image_reward:
self.buffer_seg1 = np.empty((self.capacity, size_segment, self.ds+self.da), dtype=np.float32)
self.buffer_seg2 = np.empty((self.capacity, size_segment, self.ds+self.da), dtype=np.float32)
else:
assert self.size_segment == 1
self.buffer_seg1 = np.empty((self.capacity, 1, image_height, image_width, 3), dtype=np.uint8)
self.buffer_seg2 = np.empty((self.capacity, 1, image_height, image_width, 3), dtype=np.uint8)
self.image_height = image_height
self.image_width = image_width
self.resize_factor = resize_factor
self.buffer_label = np.empty((self.capacity, 1), dtype=np.float32)
self.buffer_index = 0
self.buffer_full = False
self.construct_ensemble()
self.inputs = []
self.targets = []
self.raw_actions = []
self.img_inputs = []
self.mb_size = mb_size
self.origin_mb_size = mb_size
if not image_reward:
self.train_batch_size = 128
else:
if not self.resnet:
self.train_batch_size = 64
else:
self.train_batch_size = 32
self.CEloss = nn.CrossEntropyLoss()
self.running_means = []
self.running_stds = []
self.best_seg = []
self.best_label = []
self.best_action = []
self.large_batch = large_batch
# new teacher
self.teacher_beta = teacher_beta
self.teacher_gamma = teacher_gamma
self.teacher_eps_mistake = teacher_eps_mistake
self.teacher_eps_equal = teacher_eps_equal
self.teacher_eps_skip = teacher_eps_skip
self.teacher_thres_skip = 0
self.teacher_thres_equal = 0
self.label_margin = label_margin
self.label_target = 1 - 2*self.label_margin
# vlm label
self.vlm_label = vlm_label
self.env_name = env_name
self.vlm = vlm
self.clip_prompt = clip_prompt
self.vlm_label_acc = 0
self.log_dir = log_dir
self.flip_vlm_label = flip_vlm_label
self.train_times = 0
self.save_query_interval = save_query_interval
file_path = os.path.abspath(__file__)
dir_path = os.path.dirname(file_path)
self.cached_label_path = "{}/{}".format(dir_path, cached_label_path)
self.read_cache_idx = 0
if self.cached_label_path is not None:
all_cached_labels = sorted(os.listdir(self.cached_label_path))
self.all_cached_labels = [os.path.join(self.cached_label_path, x) for x in all_cached_labels]
def eval(self,):
for i in range(self.de):
self.ensemble[i].eval()
def train(self,):
for i in range(self.de):
self.ensemble[i].train()
def softXEnt_loss(self, input, target):
logprobs = torch.nn.functional.log_softmax (input, dim = 1)
return -(target * logprobs).sum() / input.shape[0]
def change_batch(self, new_frac):
self.mb_size = int(self.origin_mb_size*new_frac)
def set_batch(self, new_batch):
self.mb_size = int(new_batch)
def set_teacher_thres_skip(self, new_margin):
self.teacher_thres_skip = new_margin * self.teacher_eps_skip
def set_teacher_thres_equal(self, new_margin):
self.teacher_thres_equal = new_margin * self.teacher_eps_equal
    def construct_ensemble(self):
        """Build ``self.de`` reward networks and one shared Adam optimizer.

        State-based rewards use an MLP over (s, a); image rewards use either
        the custom CNN or a ResNet-18 head depending on ``self.resnet``.
        All member parameters go into a single flat list so one optimizer
        step updates the whole ensemble.
        """
        for i in range(self.de):
            if not self.image_reward:
                model = nn.Sequential(*gen_net(in_size=self.ds+self.da,
                    out_size=1, H=self.reward_model_H, n_layers=self.reward_model_layers,
                    activation=self.activation)).float().to(device)
            else:
                if not self.resnet:
                    model = gen_image_net(self.image_height, self.image_width, self.conv_kernel_sizes, self.conv_n_channels, self.conv_strides).float().to(device)
                else:
                    model = gen_image_net2().float().to(device)
            self.ensemble.append(model)
            self.paramlst.extend(model.parameters())
        self.opt = torch.optim.Adam(self.paramlst, lr = self.lr)
    def add_data(self, obs, act, rew, done, img=None):
        """Append one environment transition to the trajectory store.

        Transitions accumulate into the last entry of ``self.inputs`` /
        ``self.targets`` (and ``self.img_inputs`` when ``img`` is given);
        ``done`` closes the current trajectory and starts a new one, with a
        FIFO cap of ``self.max_size`` trajectories. 'Cloth' envs are treated
        as 1-step MDPs: each done stores the episode as a single-step entry.
        """
        sa_t = np.concatenate([obs, act], axis=-1)
        r_t = rew
        flat_input = sa_t.reshape(1, self.da+self.ds)
        r_t = np.array(r_t)
        flat_target = r_t.reshape(1, 1)
        if img is not None:
            # add a leading batch axis: (1, H, W, C)
            flat_img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
        init_data = len(self.inputs) == 0
        if init_data:
            # very first transition ever: seed the trajectory lists
            self.inputs.append(flat_input)
            self.targets.append(flat_target)
            if img is not None:
                self.img_inputs.append(flat_img)
        elif done:
            if 'Cloth' not in self.env_name:
                # close the current trajectory with this final transition
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
                if img is not None:
                    self.img_inputs[-1] = np.concatenate([self.img_inputs[-1], flat_img], axis=0)
                # FIFO
                if len(self.inputs) > self.max_size:
                    self.inputs = self.inputs[1:]
                    self.targets = self.targets[1:]
                    if img is not None:
                        self.img_inputs = self.img_inputs[1:]
                # open an empty slot for the next trajectory
                self.inputs.append([])
                self.targets.append([])
                if img is not None:
                    self.img_inputs.append([])
            else: # cloth-fold env is only a 1-step MDP
                self.inputs.append([flat_input])
                self.targets.append([flat_target])
                if img is not None:
                    self.img_inputs.append([flat_img])
                # FIFO
                if len(self.inputs) > self.max_size:
                    self.inputs = self.inputs[1:]
                    self.targets = self.targets[1:]
                    if img is not None:
                        self.img_inputs = self.img_inputs[1:]
        else:
            if len(self.inputs[-1]) == 0:
                # first transition of a fresh trajectory slot
                self.inputs[-1] = flat_input
                self.targets[-1] = flat_target
                if img is not None:
                    self.img_inputs[-1] = flat_img
            else:
                # mid-trajectory: extend the current arrays
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
                if img is not None:
                    self.img_inputs[-1] = np.concatenate([self.img_inputs[-1], flat_img], axis=0)
def add_data_batch(self, obses, rewards):
num_env = obses.shape[0]
for index in range(num_env):
self.inputs.append(obses[index])
self.targets.append(rewards[index])
def get_rank_probability(self, x_1, x_2):
# get probability x_1 > x_2
probs = []
for member in range(self.de):
probs.append(self.p_hat_member(x_1, x_2, member=member).cpu().numpy())
probs = np.array(probs)
return np.mean(probs, axis=0), np.std(probs, axis=0)
def get_entropy(self, x_1, x_2):
# get probability x_1 > x_2
probs = []
for member in range(self.de):
probs.append(self.p_hat_entropy(x_1, x_2, member=member).cpu().numpy())
probs = np.array(probs)
return np.mean(probs, axis=0), np.std(probs, axis=0)
    def p_hat_member(self, x_1, x_2, member=-1):
        """P(x_1 preferred over x_2) under one ensemble member (Bradley-Terry).

        Segment returns are summed over the time axis, the two scalars are
        softmaxed, and index 0 (x_1's probability) is returned. No gradients.
        """
        # softmaxing to get the probabilities according to eqn 1
        with torch.no_grad():
            r_hat1 = self.r_hat_member(x_1, member=member)
            r_hat2 = self.r_hat_member(x_2, member=member)
            r_hat1 = r_hat1.sum(axis=1)
            r_hat2 = r_hat2.sum(axis=1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
        # taking 0 index for probability x_1 > x_2
        return F.softmax(r_hat, dim=-1)[:,0]
    def p_hat_entropy(self, x_1, x_2, member=-1):
        """Entropy of one member's preference distribution over (x_1, x_2).

        Computed as |sum(p * log p)|, i.e. the (positive) Shannon entropy of
        the two-way softmax; higher means the member is more uncertain.
        """
        # softmaxing to get the probabilities according to eqn 1
        with torch.no_grad():
            r_hat1 = self.r_hat_member(x_1, member=member)
            r_hat2 = self.r_hat_member(x_2, member=member)
            r_hat1 = r_hat1.sum(axis=1)
            r_hat2 = r_hat2.sum(axis=1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
            ent = F.softmax(r_hat, dim=-1) * F.log_softmax(r_hat, dim=-1)
            ent = ent.sum(axis=-1).abs()
        return ent
    def r_hat_member(self, x, member=-1):
        """Run one ensemble member on numpy input ``x`` (moved to ``device``).

        The network parameterizes r-hat in eqn 1 of the PEBBLE-style setup;
        gradients flow, so this is used for both training and inference.
        """
        return self.ensemble[member](torch.from_numpy(x).float().to(device))
def r_hat(self, x):
# they say they average the rewards from each member of the ensemble, but I think this only makes sense if the rewards are already normalized
# but I don't understand how the normalization should be happening right now :(
r_hats = []
for member in range(self.de):
r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
r_hats = np.array(r_hats)
return np.mean(r_hats)
def r_hat_batch(self, x):
# they say they average the rewards from each member of the ensemble, but I think this only makes sense if the rewards are already normalized
# but I don't understand how the normalization should be happening right now :(
r_hats = []
for member in range(self.de):
r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
r_hats = np.array(r_hats)
return np.mean(r_hats, axis=0)
    def save(self, model_dir, step):
        """Save each ensemble member's state dict to
        ``{model_dir}/reward_model_{step}_{member}.pt``."""
        for member in range(self.de):
            torch.save(
                self.ensemble[member].state_dict(), '%s/reward_model_%s_%s.pt' % (model_dir, step, member)
            )
    def load(self, model_dir, step):
        """Load each ensemble member's state dict saved by :meth:`save`.

        ``model_dir`` is resolved relative to this source file's directory.
        """
        file_dir = os.path.dirname(os.path.realpath(__file__))
        model_dir = os.path.join(file_dir, model_dir)
        for member in range(self.de):
            self.ensemble[member].load_state_dict(
                torch.load('%s/reward_model_%s_%s.pt' % (model_dir, step, member))
            )
    def get_train_acc(self):
        """Mean preference accuracy of the ensemble over the whole label buffer.

        NOTE(review): ``total_batch_index`` is computed but unused — batches
        are read sequentially; and unlike ``train_reward`` the image
        transpose/normalize step is not applied here, so this presumably
        targets the state-based path — confirm before using with images.
        """
        ensemble_acc = np.array([0 for _ in range(self.de)])
        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = np.random.permutation(max_len)
        batch_size = 256
        num_epochs = int(np.ceil(max_len/batch_size))
        total = 0
        for epoch in range(num_epochs):
            last_index = (epoch+1)*batch_size
            if (epoch+1)*batch_size > max_len:
                last_index = max_len
            sa_t_1 = self.buffer_seg1[epoch*batch_size:last_index]
            sa_t_2 = self.buffer_seg2[epoch*batch_size:last_index]
            labels = self.buffer_label[epoch*batch_size:last_index]
            labels = torch.from_numpy(labels.flatten()).long().to(device)
            total += labels.size(0)
            for member in range(self.de):
                # get logits
                r_hat1 = self.r_hat_member(sa_t_1, member=member)
                r_hat2 = self.r_hat_member(sa_t_2, member=member)
                r_hat1 = r_hat1.sum(axis=1)
                r_hat2 = r_hat2.sum(axis=1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
                # argmax over the two summed returns = predicted preference
                _, predicted = torch.max(r_hat.data, 1)
                correct = (predicted == labels).sum().item()
                ensemble_acc[member] += correct
        ensemble_acc = ensemble_acc / total
        return np.mean(ensemble_acc)
    def get_queries(self, mb_size=20):
        """Sample ``mb_size`` pairs of segments from the stored trajectories.

        Returns (sa_t_1, sa_t_2, r_t_1, r_t_2) with shape
        (mb_size, size_segment, ...), plus (img_t_1, img_t_2) when VLM labels
        or image rewards are enabled. Images for a segment are concatenated
        horizontally into one strip per query.
        """
        len_traj, max_len = len(self.inputs[0]), len(self.inputs)
        # drop the last trajectory if it is still being filled
        if len(self.inputs[-1]) < len_traj:
            max_len = max_len - 1
        # get train traj
        train_inputs = np.array(self.inputs[:max_len])
        train_targets = np.array(self.targets[:max_len])
        if self.vlm_label or self.image_reward:
            train_images = np.array(self.img_inputs[:max_len])
            if 'Cloth' in self.env_name:
                train_images = train_images.squeeze(1)
        # sample trajectories (with replacement) for each side of the pair
        batch_index_2 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_2 = train_inputs[batch_index_2] # Batch x T x dim of s&a
        r_t_2 = train_targets[batch_index_2] # Batch x T x 1
        if self.vlm_label or self.image_reward:
            img_t_2 = train_images[batch_index_2] # Batch x T x *img_dim
        batch_index_1 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_1 = train_inputs[batch_index_1] # Batch x T x dim of s&a
        r_t_1 = train_targets[batch_index_1] # Batch x T x 1
        if self.vlm_label or self.image_reward:
            img_t_1 = train_images[batch_index_1] # Batch x T x *img_dim
        # flatten the (batch, time) axes so segments can be gathered with np.take
        sa_t_1 = sa_t_1.reshape(-1, sa_t_1.shape[-1]) # (Batch x T) x dim of s&a
        r_t_1 = r_t_1.reshape(-1, r_t_1.shape[-1]) # (Batch x T) x 1
        sa_t_2 = sa_t_2.reshape(-1, sa_t_2.shape[-1]) # (Batch x T) x dim of s&a
        r_t_2 = r_t_2.reshape(-1, r_t_2.shape[-1]) # (Batch x T) x 1
        if self.vlm_label or self.image_reward:
            img_t_1 = img_t_1.reshape(-1, img_t_1.shape[2], img_t_1.shape[3], img_t_1.shape[4])
            img_t_2 = img_t_2.reshape(-1, img_t_2.shape[2], img_t_2.shape[3], img_t_2.shape[4])
        # Generate time index
        time_index = np.array([list(range(i*len_traj, i*len_traj+self.size_segment)) for i in range(mb_size)])
        if 'Cloth' not in self.env_name:
            # random segment start offset within each sampled trajectory
            random_idx_2 = np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
            time_index_2 = time_index + random_idx_2
            random_idx_1 = np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
            time_index_1 = time_index + random_idx_1
        else:
            # 1-step cloth episodes: segment is the whole trajectory
            time_index_2 = time_index
            time_index_1 = time_index
        if self.vlm_label or self.image_reward:
            if self.vlm_label == 1 or self.image_reward: # use a single image for querying vlm for the labeling
                image_time_index = np.array([[i*len_traj+self.size_segment - 1] for i in range(mb_size)])
            else:
                # take self.vlm_label evenly spaced frames ending at the segment's last step
                interval = self.size_segment // self.vlm_label
                image_time_index = np.array([[i * len_traj + self.size_segment - 1 - j * interval for j in range(self.vlm_label - 1, -1, -1)] for i in range(mb_size)])
                image_time_index = np.maximum(image_time_index, 0)
            if 'Cloth' not in self.env_name:
                image_time_index_2 = image_time_index + random_idx_2
                image_time_index_1 = image_time_index + random_idx_1
            else:
                image_time_index_2 = image_time_index
                image_time_index_1 = image_time_index
        sa_t_1 = np.take(sa_t_1, time_index_1, axis=0) # Batch x size_seg x dim of s&a
        r_t_1 = np.take(r_t_1, time_index_1, axis=0) # Batch x size_seg x 1
        sa_t_2 = np.take(sa_t_2, time_index_2, axis=0) # Batch x size_seg x dim of s&a
        r_t_2 = np.take(r_t_2, time_index_2, axis=0) # Batch x size_seg x 1
        if self.vlm_label or self.image_reward:
            img_t_1 = np.take(img_t_1, image_time_index_1, axis=0) # Batch x vlm_label x *img_dim
            img_t_2 = np.take(img_t_2, image_time_index_2, axis=0) # Batch x vlm_label x *img_dim
            # tile the selected frames side by side into one wide image per query
            batch_size, horizon, image_height, image_width, _ = img_t_1.shape
            transposed_images = np.transpose(img_t_1, (0, 2, 1, 3, 4))
            img_t_1 = transposed_images.reshape(batch_size, image_height, horizon * image_width, 3) # batch x image_height x (time_horizon * image_width) x 3
            transposed_images = np.transpose(img_t_2, (0, 2, 1, 3, 4))
            img_t_2 = transposed_images.reshape(batch_size, image_height, horizon * image_width, 3) # batch x image_height x (time_horizon * image_width) x 3
        if not self.vlm_label and not self.image_reward:
            return sa_t_1, sa_t_2, r_t_1, r_t_2
        else:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2
def put_queries(self, sa_t_1, sa_t_2, labels):
total_sample = sa_t_1.shape[0]
next_index = self.buffer_index + total_sample
# NOTE: buffer_seg is overloaded. When not using image based rewards, it gives concatenated state action pairs. When image based rewards are used, it gives the images.
if next_index >= self.capacity:
self.buffer_full = True
maximum_index = self.capacity - self.buffer_index
np.copyto(self.buffer_seg1[self.buffer_index:self.capacity], sa_t_1[:maximum_index])
np.copyto(self.buffer_seg2[self.buffer_index:self.capacity], sa_t_2[:maximum_index])
np.copyto(self.buffer_label[self.buffer_index:self.capacity], labels[:maximum_index])
remain = total_sample - (maximum_index)
if remain > 0:
np.copyto(self.buffer_seg1[0:remain], sa_t_1[maximum_index:])
np.copyto(self.buffer_seg2[0:remain], sa_t_2[maximum_index:])
np.copyto(self.buffer_label[0:remain], labels[maximum_index:])
self.buffer_index = remain
else:
if self.image_reward:
sa_t_1 = sa_t_1.reshape(sa_t_1.shape[0], 1, sa_t_1.shape[1], sa_t_1.shape[2], sa_t_1.shape[3])
sa_t_2 = sa_t_2.reshape(sa_t_2.shape[0], 1, sa_t_2.shape[1], sa_t_2.shape[2], sa_t_2.shape[3])
np.copyto(self.buffer_seg1[self.buffer_index:next_index], sa_t_1)
np.copyto(self.buffer_seg2[self.buffer_index:next_index], sa_t_2)
np.copyto(self.buffer_label[self.buffer_index:next_index], labels)
self.buffer_index = next_index
    def get_label(self, sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1=None, img_t_2=None):
        """Produce preference labels for query pairs.

        Scripted-teacher labels come from (discounted) segment return
        comparison with optional skip/equal/mistake noise. When
        ``self.vlm_label`` is set, a VLM (GPT-4V or Gemini) is additionally
        queried on the image pairs; pairs the VLM rejects (-1) or that are
        visually identical are filtered out. Returns the (possibly filtered)
        queries plus labels; with VLM labels the VLM labels are appended.
        """
        sum_r_t_1 = np.sum(r_t_1, axis=1)
        sum_r_t_2 = np.sum(r_t_2, axis=1)
        # skip the query: drop pairs whose best return is below the skip threshold
        if self.teacher_thres_skip > 0:
            max_r_t = np.maximum(sum_r_t_1, sum_r_t_2)
            max_index = (max_r_t > self.teacher_thres_skip).reshape(-1)
            if sum(max_index) == 0:
                return None, None, None, None, []
            sa_t_1 = sa_t_1[max_index]
            sa_t_2 = sa_t_2[max_index]
            r_t_1 = r_t_1[max_index]
            r_t_2 = r_t_2[max_index]
            sum_r_t_1 = np.sum(r_t_1, axis=1)
            sum_r_t_2 = np.sum(r_t_2, axis=1)
        # equally preferable: mark pairs whose returns are within the equal threshold
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) < self.teacher_thres_equal).reshape(-1)
        # perfectly rational: discount earlier steps by teacher_gamma
        # (each prefix [0..index] is multiplied once per loop iteration)
        seg_size = r_t_1.shape[1]
        temp_r_t_1 = r_t_1.copy()
        temp_r_t_2 = r_t_2.copy()
        for index in range(seg_size-1):
            temp_r_t_1[:,:index+1] *= self.teacher_gamma
            temp_r_t_2[:,:index+1] *= self.teacher_gamma
        sum_r_t_1 = np.sum(temp_r_t_1, axis=1)
        sum_r_t_2 = np.sum(temp_r_t_2, axis=1)
        # label 1 means segment 2 is preferred
        rational_labels = 1*(sum_r_t_1 < sum_r_t_2)
        if self.teacher_beta > 0: # Bradley-Terry rational model
            r_hat = torch.cat([torch.Tensor(sum_r_t_1),
                               torch.Tensor(sum_r_t_2)], axis=-1)
            r_hat = r_hat*self.teacher_beta
            ent = F.softmax(r_hat, dim=-1)[:, 1]
            labels = torch.bernoulli(ent).int().numpy().reshape(-1, 1)
        else:
            labels = rational_labels
        # making a mistake: flip labels with probability teacher_eps_mistake
        len_labels = labels.shape[0]
        rand_num = np.random.rand(len_labels)
        noise_index = rand_num <= self.teacher_eps_mistake
        labels[noise_index] = 1 - labels[noise_index]
        # equally preferable
        labels[margin_index] = -1
        if self.vlm_label:
            ts = time.time()
            time_string = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M-%S')
            gpt_two_image_paths = []
            combined_images_list = []
            useful_indices = []
            file_path = os.path.abspath(__file__)
            dir_path = os.path.dirname(file_path)
            save_path = "{}/data/gpt_query_image/{}/{}".format(dir_path, self.env_name, time_string)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # save each image pair to disk (GPT path needs file paths) and
            # record which pairs are visually distinguishable
            for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                combined_image = np.concatenate([img1, img2], axis=1)
                combined_images_list.append(combined_image)
                combined_image = Image.fromarray(combined_image)
                first_image_save_path = os.path.join(save_path, "first_{:06}.png".format(idx))
                second_image_save_path = os.path.join(save_path, "second_{:06}.png".format(idx))
                Image.fromarray(img1).save(first_image_save_path)
                Image.fromarray(img2).save(second_image_save_path)
                gpt_two_image_paths.append([first_image_save_path, second_image_save_path])
                diff = np.linalg.norm(img1 - img2)
                if diff < 1e-3: # ignore the pair if the image is exactly the same
                    useful_indices.append(0)
                else:
                    useful_indices.append(1)
            if self.vlm == 'gpt4v_two_image':
                from vlms.gpt4_infer import gpt4v_infer_2
                vlm_labels = []
                for idx, (img_path_1, img_path_2) in enumerate(gpt_two_image_paths):
                    print("querying vlm {}/{}".format(idx, len(gpt_two_image_paths)))
                    query_prompt = gpt_free_query_env_prompts[self.env_name]
                    summary_prompt = gpt_summary_env_prompts[self.env_name]
                    res = gpt4v_infer_2(query_prompt, summary_prompt, img_path_1, img_path_2)
                    try:
                        label_res = int(res)
                    except:
                        # unparsable VLM answer -> mark invalid (-1)
                        label_res = -1
                    vlm_labels.append(label_res)
                    time.sleep(0.1)
            elif self.vlm == 'gemini_single_prompt':
                vlm_labels = []
                for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                    res = gemini_query_1([
                        gemini_free_query_prompt1,
                        Image.fromarray(img1),
                        gemini_free_query_prompt2,
                        Image.fromarray(img2),
                        gemini_single_query_env_prompts[self.env_name],
                    ])
                    # NOTE(review): substring checks, so "-1" is matched before
                    # "1"; any other text falls through to -1 (invalid)
                    try:
                        if "-1" in res:
                            res = -1
                        elif "0" in res:
                            res = 0
                        elif "1" in res:
                            res = 1
                        else:
                            res = -1
                    except:
                        res = -1
                    vlm_labels.append(res)
            elif self.vlm == "gemini_free_form":
                vlm_labels = []
                for idx, (img1, img2) in enumerate(zip(img_t_1, img_t_2)):
                    res = gemini_query_2(
                        [
                            gemini_free_query_prompt1,
                            Image.fromarray(img1),
                            gemini_free_query_prompt2,
                            Image.fromarray(img2),
                            gemini_free_query_env_prompts[self.env_name]
                        ],
                        gemini_summary_env_prompts[self.env_name]
                    )
                    try:
                        res = int(res)
                        if res not in [0, 1, -1]:
                            res = -1
                    except:
                        res = -1
                    vlm_labels.append(res)
            vlm_labels = np.array(vlm_labels).reshape(-1, 1)
            # keep only pairs the VLM labelled (not -1) AND that are visually distinct
            good_idx = (vlm_labels != -1).flatten()
            useful_indices = (np.array(useful_indices) == 1).flatten()
            good_idx = np.logical_and(good_idx, useful_indices)
            sa_t_1 = sa_t_1[good_idx]
            sa_t_2 = sa_t_2[good_idx]
            r_t_1 = r_t_1[good_idx]
            r_t_2 = r_t_2[good_idx]
            rational_labels = rational_labels[good_idx]
            vlm_labels = vlm_labels[good_idx]
            combined_images_list = np.array(combined_images_list)[good_idx]
            img_t_1 = img_t_1[good_idx]
            img_t_2 = img_t_2[good_idx]
            if self.flip_vlm_label:
                vlm_labels = 1 - vlm_labels
            # periodically persist the query set (always for paid GPT-4V queries)
            if self.train_times % self.save_query_interval == 0 or 'gpt4v' in self.vlm:
                save_path = os.path.join(self.log_dir, "vlm_label_set")
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                with open("{}/{}.pkl".format(save_path, time_string), "wb") as f:
                    pkl.dump([combined_images_list, rational_labels, vlm_labels, sa_t_1, sa_t_2, r_t_1, r_t_2], f, protocol=pkl.HIGHEST_PROTOCOL)
            # track agreement between VLM labels and the scripted teacher
            acc = 0
            if len(vlm_labels) > 0:
                acc = np.sum(vlm_labels == rational_labels) / len(vlm_labels)
                print("vlm label acc: {}".format(acc))
                print("vlm label acc: {}".format(acc))
                print("vlm label acc: {}".format(acc))
            else:
                print("no vlm label")
                print("no vlm label")
                print("no vlm label")
            self.vlm_label_acc = acc
            if not self.image_reward:
                return sa_t_1, sa_t_2, r_t_1, r_t_2, labels, vlm_labels
            else:
                return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels, vlm_labels
        if not self.image_reward:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, labels
        else:
            return sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels
    def kcenter_sampling(self):
        """Query selection via greedy k-center over the state parts of the pairs.

        Samples a large candidate batch, keeps the ``mb_size`` candidates
        farthest from the already-buffered queries, labels them, and stores
        them. Returns the number of labels obtained.
        """
        # get queries
        num_init = self.mb_size*self.large_batch
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)
        # get final queries based on kmeans clustering
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]
        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init, -1),
                                  temp_sa_t_2.reshape(num_init, -1)], axis=1)
        max_len = self.capacity if self.buffer_full else self.buffer_index
        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)
        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)
        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)
    def kcenter_disagree_sampling(self):
        """Hybrid query selection: ensemble disagreement first, then k-center.

        Keeps the top half of a large candidate batch by ensemble std of the
        preference probability, then picks ``mb_size`` of those via greedy
        k-center. Returns the number of labels obtained.
        """
        num_init = self.mb_size*self.large_batch
        num_init_half = int(num_init*0.5)
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)
        # get final queries based on uncertainty
        _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
        top_k_index = (-disagree).argsort()[:num_init_half]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]
        # get final queries based on kmeans clustering
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]
        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init_half, -1),
                                  temp_sa_t_2.reshape(num_init_half, -1)], axis=1)
        max_len = self.capacity if self.buffer_full else self.buffer_index
        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)
        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)
        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)
    def kcenter_entropy_sampling(self):
        """Hybrid query selection: preference entropy first, then k-center.

        Keeps the top half of a large candidate batch by mean ensemble
        entropy, then picks ``mb_size`` of those via greedy k-center.
        Returns the number of labels obtained.
        """
        num_init = self.mb_size*self.large_batch
        num_init_half = int(num_init*0.5)
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=num_init)
        # get final queries based on uncertainty
        entropy, _ = self.get_entropy(sa_t_1, sa_t_2)
        top_k_index = (-entropy).argsort()[:num_init_half]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]
        # get final queries based on kmeans clustering
        temp_sa_t_1 = sa_t_1[:,:,:self.ds]
        temp_sa_t_2 = sa_t_2[:,:,:self.ds]
        temp_sa = np.concatenate([temp_sa_t_1.reshape(num_init_half, -1),
                                  temp_sa_t_2.reshape(num_init_half, -1)], axis=1)
        max_len = self.capacity if self.buffer_full else self.buffer_index
        tot_sa_1 = self.buffer_seg1[:max_len, :, :self.ds]
        tot_sa_2 = self.buffer_seg2[:max_len, :, :self.ds]
        tot_sa = np.concatenate([tot_sa_1.reshape(max_len, -1),
                                 tot_sa_2.reshape(max_len, -1)], axis=1)
        selected_index = KCenterGreedy(temp_sa, tot_sa, self.mb_size)
        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)
    def uniform_sampling(self):
        """Uniformly sample, label, and store ``mb_size`` query pairs.

        Three paths: scripted-teacher labels (no VLM), live VLM labelling,
        or replay from cached VLM-label pickles. With VLM labels the stored
        labels are the VLM's answers. Returns the number of labels stored.
        """
        if not self.vlm_label:
            # get queries
            if not self.image_reward:
                sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
                    mb_size=self.mb_size)
                # get labels
                sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
                    sa_t_1, sa_t_2, r_t_1, r_t_2)
            else:
                sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2 = self.get_queries(
                    mb_size=self.mb_size)
                sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, labels = self.get_label(
                    sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
        else:
            if self.cached_label_path is None:
                # live VLM querying
                sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2 = self.get_queries(
                    mb_size=self.mb_size)
                if not self.image_reward:
                    sa_t_1, sa_t_2, r_t_1, r_t_2, gt_labels, vlm_labels = self.get_label(
                        sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
                else:
                    sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2, gt_labels, vlm_labels = self.get_label(
                        sa_t_1, sa_t_2, r_t_1, r_t_2, img_t_1, img_t_2)
            else:
                # replay labels saved from a previous run
                if self.read_cache_idx < len(self.all_cached_labels):
                    combined_images_list, sa_t_1, sa_t_2, r_t_1, r_t_2, gt_labels, vlm_labels = self.get_label_from_cached_states()
                    if self.image_reward:
                        # cached pickles store the side-by-side strip; split it
                        # back into the two query images
                        num, height, width, _ = combined_images_list.shape
                        img_t_1 = combined_images_list[:, :, :width//2, :]
                        img_t_2 = combined_images_list[:, :, width//2:, :]
                        if 'Rope' not in self.env_name and \
                            'Water' not in self.env_name:
                            resized_img_t_1 = np.zeros((num, self.image_height, self.image_width, 3), dtype=np.uint8)
                            resized_img_t_2 = np.zeros((num, self.image_height, self.image_width, 3), dtype=np.uint8)
                            for idx in range(len(img_t_1)):
                                resized_img_t_1[idx] = cv2.resize(img_t_1[idx], (self.image_height, self.image_width))
                                resized_img_t_2[idx] = cv2.resize(img_t_2[idx], (self.image_height, self.image_width))
                            img_t_1 = resized_img_t_1
                            img_t_2 = resized_img_t_2
                else:
                    # cache exhausted: nothing to add this round
                    vlm_labels = []
            labels = vlm_labels
        if len(labels) > 0:
            if not self.image_reward:
                self.put_queries(sa_t_1, sa_t_2, labels)
            else:
                # optionally downsample images before storing
                self.put_queries(img_t_1[:, ::self.resize_factor, ::self.resize_factor, :], img_t_2[:, ::self.resize_factor, ::self.resize_factor, :], labels)
        return len(labels)
def get_label_from_cached_states(self):
if self.read_cache_idx >= len(self.all_cached_labels):
return None, None, None, None, None, []
with open(self.all_cached_labels[self.read_cache_idx], 'rb') as f:
data = pkl.load(f)
combined_images_list, rational_labels, vlm_labels, sa_t_1, sa_t_2, r_t_1, r_t_2 = data
self.read_cache_idx += 1
return combined_images_list, sa_t_1, sa_t_2, r_t_1, r_t_2, rational_labels, vlm_labels
    def disagreement_sampling(self):
        """Query selection by ensemble disagreement.

        Samples a large candidate batch and keeps the ``mb_size`` pairs with
        the highest std of the ensemble's preference probability, then labels
        and stores them. Returns the number of labels obtained.
        """
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size*self.large_batch)
        # get final queries based on uncertainty
        _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
        top_k_index = (-disagree).argsort()[:self.mb_size]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)
    def entropy_sampling(self):
        """Query selection by preference entropy.

        Samples a large candidate batch and keeps the ``mb_size`` pairs with
        the highest mean ensemble entropy, then labels and stores them.
        Returns the number of labels obtained.
        """
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size*self.large_batch)
        # get final queries based on uncertainty
        entropy, _ = self.get_entropy(sa_t_1, sa_t_2)
        top_k_index = (-entropy).argsort()[:self.mb_size]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)
    def train_reward(self):
        """One pass of hard-label preference training over the label buffer.

        Each ensemble member sees its own permutation of the buffer; losses
        for all members are accumulated and applied in a single optimizer
        step per mini-batch. Returns the per-member accuracy array.
        """
        self.train_times += 1
        ensemble_losses = [[] for _ in range(self.de)]
        ensemble_acc = np.array([0 for _ in range(self.de)])
        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = []
        for _ in range(self.de):
            total_batch_index.append(np.random.permutation(max_len))
        num_epochs = int(np.ceil(max_len/self.train_batch_size))
        total = 0
        for epoch in range(num_epochs):
            self.opt.zero_grad()
            loss = 0.0
            last_index = (epoch+1)*self.train_batch_size
            if last_index > max_len:
                last_index = max_len
            for member in range(self.de):
                # get random batch
                idxs = total_batch_index[member][epoch*self.train_batch_size:last_index]
                sa_t_1 = self.buffer_seg1[idxs]
                sa_t_2 = self.buffer_seg2[idxs]
                labels = self.buffer_label[idxs]
                labels = torch.from_numpy(labels.flatten()).long().to(device)
                if member == 0:
                    total += labels.size(0)
                if self.image_reward:
                    # sa_t_1 is batch_size x segment x image_height x image_width x 3
                    sa_t_1 = np.transpose(sa_t_1, (0, 1, 4, 2, 3)) # for torch we need to transpose channel first
                    sa_t_2 = np.transpose(sa_t_2, (0, 1, 4, 2, 3))
                    # also we stored uint8 images, we need to convert them to float32
                    sa_t_1 = sa_t_1.astype(np.float32) / 255.0
                    sa_t_2 = sa_t_2.astype(np.float32) / 255.0
                    sa_t_1 = sa_t_1.squeeze(1)
                    sa_t_2 = sa_t_2.squeeze(1)
                # get logits
                r_hat1 = self.r_hat_member(sa_t_1, member=member)
                r_hat2 = self.r_hat_member(sa_t_2, member=member)
                if not self.image_reward:
                    # sum per-step rewards into a segment return
                    r_hat1 = r_hat1.sum(axis=1)
                    r_hat2 = r_hat2.sum(axis=1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
                # compute loss
                curr_loss = self.CEloss(r_hat, labels)
                loss += curr_loss
                ensemble_losses[member].append(curr_loss.item())
                # compute acc
                _, predicted = torch.max(r_hat.data, 1)
                correct = (predicted == labels).sum().item()
                ensemble_acc[member] += correct
            # single backward/step over the summed member losses
            loss.backward()
            self.opt.step()
        ensemble_acc = ensemble_acc / total
        torch.cuda.empty_cache()
        return ensemble_acc
    def train_soft_reward(self):
        """One pass of soft-label preference training over the label buffer.

        Like ``train_reward`` but targets are smoothed one-hot vectors
        (``label_target`` + ``label_margin``); equal-preference labels (-1)
        become uniform 0.5/0.5 targets. Returns the per-member accuracy
        array. NOTE(review): no image preprocessing here — this path
        presumably targets state-based rewards only; confirm before using
        with image buffers.
        """
        ensemble_losses = [[] for _ in range(self.de)]
        ensemble_acc = np.array([0 for _ in range(self.de)])
        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = []
        for _ in range(self.de):
            total_batch_index.append(np.random.permutation(max_len))
        num_epochs = int(np.ceil(max_len/self.train_batch_size))
        list_debug_loss1, list_debug_loss2 = [], []
        total = 0
        for epoch in range(num_epochs):
            self.opt.zero_grad()
            loss = 0.0
            last_index = (epoch+1)*self.train_batch_size
            if last_index > max_len:
                last_index = max_len
            for member in range(self.de):
                # get random batch
                idxs = total_batch_index[member][epoch*self.train_batch_size:last_index]
                sa_t_1 = self.buffer_seg1[idxs]
                sa_t_2 = self.buffer_seg2[idxs]
                labels = self.buffer_label[idxs]
                labels = torch.from_numpy(labels.flatten()).long().to(device)
                if member == 0:
                    total += labels.size(0)
                # get logits
                r_hat1 = self.r_hat_member(sa_t_1, member=member)
                r_hat2 = self.r_hat_member(sa_t_2, member=member)
                r_hat1 = r_hat1.sum(axis=1)
                r_hat2 = r_hat2.sum(axis=1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
                # compute loss
                # -1 marks "equally preferable"; remap to 0 for scatter, then
                # overwrite those rows with a uniform target below
                uniform_index = labels == -1
                labels[uniform_index] = 0
                target_onehot = torch.zeros_like(r_hat).scatter(1, labels.unsqueeze(1), self.label_target)
                target_onehot += self.label_margin
                if sum(uniform_index) > 0:
                    target_onehot[uniform_index] = 0.5
                curr_loss = self.softXEnt_loss(r_hat, target_onehot)
                loss += curr_loss
                ensemble_losses[member].append(curr_loss.item())
                # compute acc
                _, predicted = torch.max(r_hat.data, 1)
                correct = (predicted == labels).sum().item()
                ensemble_acc[member] += correct
            loss.backward()
            self.opt.step()
        ensemble_acc = ensemble_acc / total
        return ensemble_acc