"""Utilities for computing rewards from multiple reward models and for
balancing positive/negative training samples."""

import concurrent.futures
import random


def _compute_single_reward(reward_model, images, input_prompts):
    """Compute per-sample rewards for a single reward model."""
    reward_model_name = type(reward_model).__name__
    try:
        if reward_model_name == 'HPSClipRewardModel':
            rewards = reward_model(images, input_prompts)
            successes = [1] * len(rewards)
        elif reward_model_name == 'CLIPScoreRewardModel':
            # CLIPScore takes (prompts, images), unlike the other models.
            rewards = reward_model(input_prompts, images)
            successes = [1] * len(rewards)
        elif reward_model_name == 'ImageRewardModel':
            rewards = reward_model(images, input_prompts)
            successes = [1] * len(rewards)
        elif reward_model_name == 'UnifiedRewardModel':
            # UnifiedRewardModel reports per-sample success flags; zero out
            # the reward for any sample it failed to score.
            rewards, successes_bool = reward_model(images, input_prompts)
            rewards = [float(reward) if success else 0.0 for reward, success in zip(rewards, successes_bool)]
            successes = [1 if success else 0 for success in successes_bool]
        elif reward_model_name == 'PickScoreRewardModel':
            rewards = reward_model(images, input_prompts)
            successes = [1] * len(rewards)
        else:
            raise ValueError(f"Unknown reward model: {reward_model_name}")

        assert len(rewards) == len(input_prompts), \
            f"Length mismatch in {reward_model_name}: rewards ({len(rewards)}) != input_prompts ({len(input_prompts)})"
        assert len(successes) == len(input_prompts), \
            f"Length mismatch in {reward_model_name}: successes ({len(successes)}) != input_prompts ({len(input_prompts)})"

        return rewards, successes

    except Exception as e:
        raise ValueError(f"Error in _compute_single_reward with {reward_model_name}: {e}") from e
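
# Usage sketch (illustrative only; `model` is an instance of one of the
# reward model classes handled above, dispatched by type(model).__name__):
#
#     rewards, successes = _compute_single_reward(model, images, prompts)
#     # len(rewards) == len(successes) == len(prompts); success flags are all
#     # 1 except for UnifiedRewardModel, which can fail on individual samples.
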
def compute_reward(images, input_prompts, reward_models, reward_weights):
    """Run every reward model over the batch in parallel, then merge the
    per-model rewards into a single weighted reward per sample."""
    assert (
        len(images) == len(input_prompts)
    ), f"length of `images` ({len(images)}) must be equal to length of `input_prompts` ({len(input_prompts)})"

    rewards_dict = {}
    successes_dict = {}

    # Score the full batch with each reward model concurrently.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(reward_models)) as executor:
        future_to_model = {
            executor.submit(_compute_single_reward, reward_model, images, input_prompts): reward_model
            for reward_model in reward_models
        }

        for future in concurrent.futures.as_completed(future_to_model):
            reward_model = future_to_model[future]
            model_name = type(reward_model).__name__
            try:
                model_rewards, model_successes = future.result()
                rewards_dict[model_name] = model_rewards
                successes_dict[model_name] = model_successes
            except Exception as e:
                # A failed model contributes zero reward and marks every
                # sample as unsuccessful for that model.
                print(f"Error computing reward with {model_name}: {e}")
                rewards_dict[model_name] = [0.0] * len(input_prompts)
                successes_dict[model_name] = [0] * len(input_prompts)

    merged_rewards = [0.0] * len(input_prompts)
    merged_successes = [0] * len(input_prompts)

    # A sample's merged reward is the weighted sum over models; it counts as
    # successful only if every weighted model succeeded on that sample.
    for i in range(len(merged_rewards)):
        all_success = True
        for model_name in reward_weights.keys():
            if model_name in successes_dict and successes_dict[model_name][i] != 1:
                all_success = False
                break

        if all_success:
            for model_name, weight in reward_weights.items():
                if model_name in rewards_dict:
                    merged_rewards[i] += rewards_dict[model_name][i] * weight
            merged_successes[i] = 1

    return merged_rewards, merged_successes, rewards_dict, successes_dict
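
# Usage sketch (illustrative only). The stub classes below are hypothetical
# stand-ins for the real reward models; they only need the right class name,
# since _compute_single_reward dispatches on type(reward_model).__name__.
#
#     class HPSClipRewardModel:
#         def __call__(self, images, prompts):
#             return [0.5] * len(prompts)
#
#     class PickScoreRewardModel:
#         def __call__(self, images, prompts):
#             return [0.8] * len(prompts)
#
#     models = [HPSClipRewardModel(), PickScoreRewardModel()]
#     weights = {'HPSClipRewardModel': 0.5, 'PickScoreRewardModel': 0.5}
#     rewards, successes, rewards_dict, successes_dict = compute_reward(
#         ['img0', 'img1'], ['a cat', 'a dog'], models, weights)
#     # rewards == [0.65, 0.65]; successes == [1, 1]
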
def balance_pos_neg(samples, use_random=False):
    """Balance the distribution of positive and negative samples.

    Samples with a zero advantage are dropped. The two groups are shuffled,
    interleaved one-by-one until the smaller group is exhausted, and the
    remainder of the larger group is appended at the end.
    """
    if use_random:
        return random.sample(samples, len(samples))

    positive_samples = [sample for sample in samples if sample['advantages'].item() > 0]
    negative_samples = [sample for sample in samples if sample['advantages'].item() < 0]

    # random.sample over the whole list is an out-of-place shuffle.
    positive_samples = random.sample(positive_samples, len(positive_samples))
    negative_samples = random.sample(negative_samples, len(negative_samples))

    num_positive = len(positive_samples)
    num_negative = len(negative_samples)

    balanced_samples = []

    if num_positive < num_negative:
        smaller_group = positive_samples
        larger_group = negative_samples
    else:
        smaller_group = negative_samples
        larger_group = positive_samples

    # Alternate one sample from each group while both still have samples.
    for i in range(len(smaller_group)):
        balanced_samples.append(smaller_group[i])
        balanced_samples.append(larger_group[i])

    # Append whatever is left of the larger group.
    remaining_samples = larger_group[len(smaller_group):]
    balanced_samples.extend(remaining_samples)
    return balanced_samples
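
# Usage sketch (illustrative only; assumes each sample dict holds a scalar
# tensor under 'advantages', e.g. a 0-d torch.Tensor):
#
#     import torch
#     samples = [{'advantages': torch.tensor(a)}
#                for a in (0.3, -0.1, -0.7, 0.2, -0.4, -0.2)]
#     balanced = balance_pos_neg(samples)
#     # 2 positives, 4 negatives -> signs alternate [+, -, +, -] and the two
#     # leftover negatives follow: [+, -, +, -, -, -]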