# studyOverflow's picture
# Add files using upload-large-folder tool
# b171568 verified
import concurrent.futures
import random
def _compute_single_reward(reward_model, images, input_prompts):
"""Compute reward for a single reward model."""
reward_model_name = type(reward_model).__name__
try:
if reward_model_name == 'HPSClipRewardModel':
rewards = reward_model(images, input_prompts)
successes = [1] * len(rewards)
elif reward_model_name == 'CLIPScoreRewardModel':
rewards = reward_model(input_prompts, images)
successes = [1] * len(rewards)
elif reward_model_name == 'ImageRewardModel':
rewards = reward_model(images, input_prompts)
successes = [1] * len(rewards)
elif reward_model_name == 'UnifiedRewardModel':
rewards, successes_bool = reward_model(images, input_prompts)
rewards = [float(reward) if success else 0.0 for reward, success in zip(rewards, successes_bool)]
successes = [1 if success else 0 for success in successes_bool]
elif reward_model_name == 'PickScoreRewardModel':
rewards = reward_model(images, input_prompts)
successes = [1] * len(rewards)
else:
raise ValueError(f"Unknown reward model: {reward_model_name}")
# Verify the length of results matches input
assert len(rewards) == len(input_prompts), \
f"Length mismatch in {reward_model_name}: rewards ({len(rewards)}) != input_prompts ({len(input_prompts)})"
assert len(successes) == len(input_prompts), \
f"Length mismatch in {reward_model_name}: successes ({len(successes)}) != input_prompts ({len(input_prompts)})"
return rewards, successes
except Exception as e:
raise ValueError(f"Error in _compute_single_reward with {reward_model_name}: {e}") from e
def compute_reward(images, input_prompts, reward_models, reward_weights):
    """Compute weighted rewards for a batch of images across several reward models.

    Each reward model runs in its own worker thread. A sample's merged reward
    is the weighted sum over all weighted models, but only when every weighted
    model succeeded on that sample; otherwise the merged reward stays 0.0 and
    the sample is marked unsuccessful.

    Args:
        images: Batch of images, paired 1:1 with `input_prompts`.
        input_prompts: Prompts corresponding to `images`.
        reward_models: Reward model instances; dispatched by class name in
            `_compute_single_reward`. May be empty.
        reward_weights: Mapping of reward model class name -> weight for the
            merged score.

    Returns:
        Tuple (merged_rewards, merged_successes, rewards_dict, successes_dict);
        the dicts are keyed by reward model class name.

    Raises:
        AssertionError: if `images` and `input_prompts` differ in length.
    """
    assert (
        len(images) == len(input_prompts)
    ), f"length of `images` ({len(images)}) must be equal to length of `input_prompts` ({len(input_prompts)})"
    # Initialize results
    rewards_dict = {}
    successes_dict = {}
    # max(..., 1): ThreadPoolExecutor raises ValueError for max_workers <= 0,
    # so an empty `reward_models` list would otherwise crash before the
    # (empty) submit loop even runs.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(len(reward_models), 1)) as executor:
        # Submit one computation task per reward model.
        future_to_model = {
            executor.submit(_compute_single_reward, reward_model, images, input_prompts): reward_model
            for reward_model in reward_models
        }
        # Collect results as they complete.
        for future in concurrent.futures.as_completed(future_to_model):
            reward_model = future_to_model[future]
            model_name = type(reward_model).__name__
            try:
                model_rewards, model_successes = future.result()
                rewards_dict[model_name] = model_rewards
                successes_dict[model_name] = model_successes
            except Exception as e:
                # Best-effort: a failed model zeroes out its own column
                # instead of aborting the whole batch.
                print(f"Error computing reward with {model_name}: {e}")
                rewards_dict[model_name] = [0.0] * len(input_prompts)
                successes_dict[model_name] = [0] * len(input_prompts)
    # Merge rewards based on weights.
    merged_rewards = [0.0] * len(input_prompts)
    merged_successes = [0] * len(input_prompts)
    for i in range(len(input_prompts)):
        # Only samples on which every weighted model succeeded get a
        # weighted-sum reward; everything else stays at 0.0 / failure.
        all_success = all(
            successes_dict[model_name][i] == 1
            for model_name in reward_weights
            if model_name in successes_dict
        )
        if all_success:
            for model_name, weight in reward_weights.items():
                if model_name in rewards_dict:
                    merged_rewards[i] += rewards_dict[model_name][i] * weight
            merged_successes[i] = 1
    return merged_rewards, merged_successes, rewards_dict, successes_dict
def balance_pos_neg(samples, use_random=False):
    """Balance positive and negative samples distribution in the samples list.

    With `use_random=True`, simply returns a random permutation of `samples`.
    Otherwise, shuffles the positive-advantage and negative-advantage samples
    independently, interleaves them pairwise (smaller group first in each
    pair), and appends the leftover of the larger group.

    NOTE(review): in the non-random path, samples whose advantage is exactly 0
    are excluded from the result — presumably intentional (no learning
    signal); confirm with callers.
    """
    if use_random:
        return random.sample(samples, len(samples))
    # Partition by advantage sign, then shuffle each group independently
    # (random.sample keeps the inputs untouched).
    positives = [s for s in samples if s['advantages'].item() > 0]
    negatives = [s for s in samples if s['advantages'].item() < 0]
    positives = random.sample(positives, len(positives))
    negatives = random.sample(negatives, len(negatives))
    # Stable sort by size: on a tie the negatives come first, matching the
    # original smaller-group preference.
    smaller, larger = sorted((negatives, positives), key=len)
    # Pairwise interleave, smaller-group element first in each pair...
    balanced = [sample for pair in zip(smaller, larger) for sample in pair]
    # ...then append whatever the larger group has left over.
    balanced.extend(larger[len(smaller):])
    return balanced