import json
import math
import multiprocessing as mp
from multiprocessing import Process, Queue

import fire
import prettytable
import torch
from tqdm import tqdm

from hpsv3.inference import HPSv3RewardInferencer


def calc_rank_acc(score_sample, predict_sample):
    """Compute pairwise ranking accuracy against ground-truth rankings.

    For every pair of generations whose ground-truth ranks differ, the pair
    counts as correct when the predicted rewards order it the same way (a
    lower rank value means a better generation). Tied predicted rewards count
    toward the total but never toward the correct count; tied ground-truth
    ranks are skipped entirely.
    """
    tol_cnt = 0.0
    true_cnt = 0.0
    for idx in range(len(score_sample)):
        item_base = score_sample[idx]["ranking"]
        item = predict_sample[idx]["rewards"]
        for i in range(len(item_base)):
            for j in range(i + 1, len(item_base)):
                if item_base[i] > item_base[j]:
                    # Ground truth prefers j, so the prediction must too.
                    tol_cnt += 1
                    if item[i] < item[j]:
                        true_cnt += 1
                elif item_base[i] < item_base[j]:
                    # Ground truth prefers i, so the prediction must too.
                    tol_cnt += 1
                    if item[i] > item[j]:
                        true_cnt += 1
    return true_cnt / tol_cnt
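
# A minimal worked example with hypothetical values: a ground-truth "ranking"
# of [2, 1] (lower rank = better) and predicted "rewards" of [0.3, 0.9] order
# the single pair the same way, so the accuracy is 1.0:
#
#   calc_rank_acc([{"ranking": [2, 1]}], [{"rewards": [0.3, 0.9]}])  # -> 1.0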


def worker_process(process_id, data_chunk, config_path, checkpoint_path, batch_size, result_queue, mode):
    """Score one chunk of the data in a dedicated process.

    GPUs are assigned round-robin by process id, and results are returned to
    the parent through `result_queue`.
    """
    num_gpus = torch.cuda.device_count()
    device = f"cuda:{process_id % num_gpus}" if num_gpus > 0 else "cpu"
    dtype = torch.bfloat16

    print(f"Process {process_id} starting with device {device}, processing {len(data_chunk)} items")

    inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device, dtype=dtype)

    process_correct = 0
    process_equal = 0
    process_results = []

    for batch_start in tqdm(range(0, len(data_chunk), batch_size),
                            total=(len(data_chunk) + batch_size - 1) // batch_size,
                            desc=f"Process {process_id}"):
        batch_end = min(batch_start + batch_size, len(data_chunk))
        batch_info = data_chunk[batch_start:batch_end]

        if mode == 'pair':
            image_paths_1 = [info["path1"] for info in batch_info]
            image_paths_2 = [info["path2"] for info in batch_info]
            prompts = [info["prompt"] for info in batch_info]

            with torch.no_grad():
                rewards_1 = inferencer.reward(image_paths_1, prompts)
                rewards_2 = inferencer.reward(image_paths_2, prompts)

            for i in range(len(batch_info)):
                info = batch_info[i]
                # If the inferencer returns a 2D tensor, the first entry per
                # image is the scalar score.
                if rewards_1.ndim == 2:
                    reward_1, reward_2 = rewards_1[i][0].item(), rewards_2[i][0].item()
                else:
                    reward_1, reward_2 = rewards_1[i].item(), rewards_2[i].item()

                item_result = {
                    'reward_1': reward_1,
                    'reward_2': reward_2,
                    'correct': reward_1 > reward_2,
                    'equal': reward_1 == reward_2,
                    'info': info,
                }
                process_results.append(item_result)

                print(f"Process {process_id} - Reward 1: {reward_1}, Reward 2: {reward_2}")
                if reward_1 > reward_2:
                    process_correct += 1
                if reward_1 == reward_2:
                    process_equal += 1

        elif mode == 'ranking':
            for item in batch_info:
                with torch.no_grad():
                    rewards = inferencer.reward(item["generations"], item["prompt"])
                # Convert scores to plain Python floats before they cross the
                # process boundary: CUDA tensors are not reliably picklable
                # through a multiprocessing queue.
                if torch.is_tensor(rewards):
                    rewards = rewards.detach().float().cpu()
                    rewards = rewards[:, 0].tolist() if rewards.ndim == 2 else rewards.tolist()
                process_results.append({
                    "id": item["id"],
                    "prompt": item["prompt"],
                    "rewards": rewards,
                })

    if mode == 'pair':
        result_queue.put({
            'process_id': process_id,
            'correct': process_correct,
            'equal': process_equal,
            'total': len(data_chunk),
            'results': process_results,
        })
        print(f"Process {process_id} completed: {process_correct}/{len(data_chunk)} correct, {process_equal}/{len(data_chunk)} equal")
    elif mode == 'ranking':
        result_queue.put({
            'process_id': process_id,
            'results': process_results,
        })
        print(f"Process {process_id} completed: {len(process_results)} items ranked")
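
# Expected test-set item fields, as inferred from the accesses above (this
# script defines no formal schema):
#   pair mode:    {"prompt": ..., "path1": ..., "path2": ...}
#   ranking mode: {"id": ..., "prompt": ..., "generations": [...], "ranking": [...]}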


def main(test_json, config_path=None, batch_size=8, num_processes=8, checkpoint_path=None, mode='pair'):
    assert mode in ['pair', 'ranking'], "Mode must be either 'pair' or 'ranking'"
    assert checkpoint_path is not None, "Checkpoint path must be provided for inference"

    # CUDA cannot be re-initialized in a forked child, so spawn fresh workers.
    mp.set_start_method('spawn', force=True)

    with open(test_json, "r") as f:
        info_list = json.load(f)

    print(f"Total items to process: {len(info_list)}")

    # Split the data into contiguous chunks, one per worker.
    chunk_size = math.ceil(len(info_list) / num_processes)
    data_chunks = []
    for i in range(num_processes):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(info_list))
        if start_idx < len(info_list):
            chunk = info_list[start_idx:end_idx]
            data_chunks.append(chunk)
            print(f"Process {i}: {len(chunk)} items (indices {start_idx}-{end_idx - 1})")

    actual_processes = len(data_chunks)
    print(f"Using {actual_processes} processes")

    result_queue = Queue()
    processes = []

    print("Starting processes...")
    for i in range(actual_processes):
        p = Process(target=worker_process,
                    args=(i, data_chunks[i], config_path, checkpoint_path, batch_size, result_queue, mode))
        p.start()
        processes.append(p)

    all_results = []
    total_correct = 0
    total_equal = 0
    total_items = 0

    # Drain the queue before joining: a child process blocks on exit until
    # everything it put on the queue has been consumed.
    print("Waiting for processes to complete...")
    for _ in range(actual_processes):
        result = result_queue.get()
        all_results.append(result)
        if mode == 'pair':
            total_correct += result['correct']
            total_equal += result['equal']
            total_items += result['total']
            print(f"Process {result['process_id']} finished: {result['correct']}/{result['total']} correct, "
                  f"{result['equal']}/{result['total']} equal")

    for p in processes:
        p.join()

    if mode == 'pair':
        aggregated_results = {
            'total_correct': total_correct,
            'total_equal': total_equal,
            'total_items': total_items,
            'accuracy': total_correct / total_items,
            'process_results': all_results,
        }

        incorrect = aggregated_results['total_items'] - aggregated_results['total_correct'] - aggregated_results['total_equal']
        accuracy_percent = 100 * aggregated_results['total_correct'] / aggregated_results['total_items']

        table = prettytable.PrettyTable()
        table.field_names = ["Total Items", "Correct", "Equal", "Incorrect", "Accuracy (%)"]
        table.add_row([
            aggregated_results['total_items'],
            aggregated_results['total_correct'],
            aggregated_results['total_equal'],
            incorrect,
            f"{accuracy_percent:.2f}",
        ])
    elif mode == 'ranking':
        # Results arrive in completion order; chunks were contiguous slices of
        # info_list, so sorting by process id restores the original order
        # before scoring (using only all_results[0] would drop every chunk
        # but one).
        all_results.sort(key=lambda r: r['process_id'])
        predictions = [item for r in all_results for item in r['results']]
        rank_acc = calc_rank_acc(info_list, predictions)

        table = prettytable.PrettyTable()
        table.field_names = ["Total Items", "Rank Accuracy (%)"]
        table.add_row([len(info_list), f"{rank_acc * 100:.2f}"])

    print(table)


if __name__ == "__main__":
    fire.Fire(main)
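
# Example invocations via python-fire (hypothetical script and file names):
#   python eval_hpsv3.py --test_json pairs.json --checkpoint_path hpsv3.pt --mode pair
#   python eval_hpsv3.py --test_json rankings.json --checkpoint_path hpsv3.pt --mode ranking --num_processes 4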