import json
import math
import multiprocessing as mp
from multiprocessing import Process, Queue

import fire
import prettytable
import torch
from tqdm import tqdm

from hpsv3.inference import HPSv3RewardInferencer
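
# Expected test_json schemas, inferred from how the fields are consumed below
# (the field names come from this file; the exact layout is an assumption):
#   pair mode:    [{"prompt": str, "path1": str, "path2": str}, ...]
#                 path1 is treated as the preferred image
#                 (reward_1 > reward_2 counts as correct).
#   ranking mode: [{"id": ..., "prompt": str, "generations": [path, ...],
#                   "ranking": [int, ...]}, ...]
#                 ranking[i] is the human rank of generations[i].
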
def calc_rank_acc(score_sample, predict_sample):
    """Pairwise rank accuracy: the fraction of image pairs whose predicted
    reward ordering agrees with the human ranking. The comparison directions
    below imply a lower-is-better rank convention (rank 1 = best)."""
    tol_cnt = 0.
    true_cnt = 0.
    for idx in range(len(score_sample)):
        item_base = score_sample[idx]["ranking"]
        item = predict_sample[idx]["rewards"]
        for i in range(len(item_base)):
            for j in range(i + 1, len(item_base)):
                if item_base[i] > item_base[j]:
                    # j is ranked better, so its reward should be higher
                    tol_cnt += 1
                    if item[i] < item[j]:
                        true_cnt += 1
                elif item_base[i] < item_base[j]:
                    # i is ranked better, so its reward should be higher
                    tol_cnt += 1
                    if item[i] > item[j]:
                        true_cnt += 1
    return true_cnt / tol_cnt
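
# A minimal sketch of the arithmetic above (made-up values): with
# ranking = [1, 3, 2] (rank 1 = best) and rewards = [0.9, 0.1, 0.5],
# all three pairs agree with the reward order (0.9 > 0.1, 0.9 > 0.5,
# 0.1 < 0.5), so the returned accuracy is 3/3 = 1.0.
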
def worker_process(process_id, data_chunk, config_path, checkpoint_path, batch_size, result_queue, mode):
    """
    Worker function for each process to handle a chunk of data
    """
    # Each process uses a different GPU (cycle through available GPUs)
    num_gpus = torch.cuda.device_count()
    device = f"cuda:{process_id % num_gpus}" if num_gpus > 0 else "cpu"
    dtype = torch.bfloat16
    print(f"Process {process_id} starting with device {device}, processing {len(data_chunk)} items")

    # Initialize model for this process
    inferencer = HPSv3RewardInferencer(config_path, checkpoint_path, device=device, dtype=dtype)

    process_correct = 0
    process_equal = 0
    process_results = []
    for batch_start in tqdm(range(0, len(data_chunk), batch_size),
                            total=(len(data_chunk) + batch_size - 1) // batch_size,
                            desc=f"Process {process_id}"):
        batch_end = min(batch_start + batch_size, len(data_chunk))
        batch_info = data_chunk[batch_start:batch_end]
        if mode == 'pair':
            image_paths_1 = [info["path1"] for info in batch_info]
            image_paths_2 = [info["path2"] for info in batch_info]
            prompts = [info["prompt"] for info in batch_info]
            with torch.no_grad():
                rewards_1 = inferencer.reward(image_paths_1, prompts)
                rewards_2 = inferencer.reward(image_paths_2, prompts)
            for i in range(len(batch_info)):
                info = batch_info[i]
                # A 2-D reward output carries the score in its first column
                if rewards_1.ndim == 2:
                    reward_1, reward_2 = rewards_1[i][0].item(), rewards_2[i][0].item()
                else:
                    reward_1, reward_2 = rewards_1[i].item(), rewards_2[i].item()
                item_result = {
                    'reward_1': reward_1,
                    'reward_2': reward_2,
                    'correct': reward_1 > reward_2,
                    'equal': reward_1 == reward_2,
                    'info': info
                }
                process_results.append(item_result)
                print(f"Process {process_id} - Reward 1: {reward_1}, Reward 2: {reward_2}")
                if reward_1 > reward_2:
                    process_correct += 1
                if reward_1 == reward_2:
                    process_equal += 1
        elif mode == 'ranking':
            for item in batch_info:
                with torch.no_grad():
                    rewards = inferencer.reward(item["generations"], item["prompt"])
                if torch.is_tensor(rewards):
                    # Mirror the pair-mode handling of 2-D outputs, then move the
                    # scores off the GPU so they pickle cleanly through the queue
                    if rewards.ndim == 2:
                        rewards = rewards[:, 0]
                    rewards = rewards.detach().float().cpu().tolist()
                predict_item = {
                    "id": item["id"],
                    "prompt": item["prompt"],
                    "rewards": rewards
                }
                process_results.append(predict_item)
    # Report this worker's results back to the parent through the queue
    if mode == 'pair':
        result_queue.put({
            'process_id': process_id,
            'correct': process_correct,
            'equal': process_equal,
            'total': len(data_chunk),
            'results': process_results
        })
        print(f"Process {process_id} completed: {process_correct}/{len(data_chunk)} correct, {process_equal}/{len(data_chunk)} equal")
    elif mode == 'ranking':
        result_queue.put({
            'process_id': process_id,
            'results': process_results
        })
        print(f"Process {process_id} completed: {len(process_results)} items scored")

def main(test_json, config_path=None, batch_size=8, num_processes=8, checkpoint_path=None, mode='pair'):
    assert mode in ['pair', 'ranking'], "Mode must be either 'pair' or 'ranking'"
    assert checkpoint_path is not None, "Checkpoint path must be provided for inference"

    # 'spawn' is required so each worker can initialize CUDA independently
    mp.set_start_method('spawn', force=True)

    with open(test_json, "r") as f:
        info_list = json.load(f)
    print(f"Total items to process: {len(info_list)}")

    # Split data into contiguous chunks, one per process
    chunk_size = math.ceil(len(info_list) / num_processes)
    data_chunks = []
    for i in range(num_processes):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(info_list))
        if start_idx < len(info_list):
            chunk = info_list[start_idx:end_idx]
            data_chunks.append(chunk)
            print(f"Process {i}: {len(chunk)} items (indices {start_idx}-{end_idx - 1})")

    # Trailing chunks are dropped when num_processes exceeds the item count
    actual_processes = len(data_chunks)
    print(f"Using {actual_processes} processes")

    # Create result queue and worker processes
    result_queue = Queue()
    processes = []
    print("Starting processes...")
    for i in range(actual_processes):
        p = Process(target=worker_process, args=(i, data_chunks[i], config_path, checkpoint_path, batch_size, result_queue, mode))
        p.start()
        processes.append(p)

    # Collect one result per process. Drain the queue before joining:
    # a worker blocked on a queue put cannot exit, so joining first can deadlock.
    all_results = []
    total_correct = 0
    total_equal = 0
    total_items = 0
    print("Waiting for processes to complete...")
    for i in range(actual_processes):
        result = result_queue.get()
        all_results.append(result)
        if mode == 'pair':
            total_correct += result['correct']
            total_equal += result['equal']
            total_items += result['total']
            print(f"Process {result['process_id']} finished: {result['correct']}/{result['total']} correct, {result['equal']}/{result['total']} equal")

    # Now that every result has been received, the workers can be joined safely
    for p in processes:
        p.join()
    if mode == 'pair':
        aggregated_results = {
            'total_correct': total_correct,
            'total_equal': total_equal,
            'total_items': total_items,
            'accuracy': total_correct / total_items,
            'process_results': all_results
        }
        table = prettytable.PrettyTable()
        table.field_names = ["Total Items", "Correct", "Equal", "Incorrect", "Accuracy (%)"]
        incorrect = aggregated_results['total_items'] - aggregated_results['total_correct'] - aggregated_results['total_equal']
        accuracy_percent = 100 * aggregated_results['total_correct'] / aggregated_results['total_items']
        table.add_row([
            aggregated_results['total_items'],
            aggregated_results['total_correct'],
            aggregated_results['total_equal'],
            incorrect,
            f"{accuracy_percent:.2f}"
        ])
    elif mode == 'ranking':
        # Results arrive in completion order; the chunks were contiguous slices
        # of info_list, so sorting by process_id and concatenating realigns the
        # predictions with the ground-truth ranking list.
        all_results.sort(key=lambda r: r['process_id'])
        predictions = [item for r in all_results for item in r['results']]
        rank_acc = calc_rank_acc(info_list, predictions)
        table = prettytable.PrettyTable()
        table.field_names = ["Total Items", "Rank Accuracy (%)"]
        table.add_row([len(info_list), f"{rank_acc * 100:.2f}"])
    print(table)


if __name__ == "__main__":
    fire.Fire(main)
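
# Example invocations via fire.Fire (script, JSON, and checkpoint names are
# illustrative, not part of the repo):
#   python eval.py --test_json=pairwise_test.json \
#       --checkpoint_path=/path/to/hpsv3.ckpt --mode=pair --num_processes=4
#   python eval.py --test_json=ranking_test.json \
#       --checkpoint_path=/path/to/hpsv3.ckpt --mode=ranking --batch_size=8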