# File size: 14,733 Bytes | cc0721b
import argparse
import os
import torch
from PIL import Image
from accelerate import Accelerator
# Ensure this path is correct and the utility is available.
from datasets import load_dataset
from torch.distributed import all_gather_object
from transformers import AutoProcessor, AutoConfig, AutoTokenizer, LlavaOnevisionForConditionalGeneration
from trl.models import unwrap_model_for_generation
from data_utils.chart.evaluator import eval_one_chart
from data_utils.rl_prompt import PROMPT_TEMPLATE
from reward_utils.compute_rewards import split_initial_context
accelerator = Accelerator()
from tqdm import tqdm
import numpy as np
# Device assigned to this process by Accelerate.
DEVICE = accelerator.device

# --- Model and processor configuration ---
# Optional extra loading kwargs, e.g. {"torch_dtype": torch.bfloat16}.
# NOTE: this dict is not passed to `from_pretrained` below.
model_args = {}

# The checkpoint location is resolved in this order:
#   1. the --model_path command-line flag,
#   2. the CHECKPOINT_DIR environment variable,
#   3. a hard-coded fallback path.
# `parse_known_args` is used so that unrelated command-line flags are ignored.
_eval_parser = argparse.ArgumentParser(add_help=False)
_eval_parser.add_argument("--model_path", default=None)
_eval_args, _ = _eval_parser.parse_known_args()

model_id = _eval_args.model_path
if not model_id:
    model_id = os.environ.get("CHECKPOINT_DIR")
if not model_id:
    model_id = "/path/to/dyme-k-8/final_checkpoint"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, config=config, trust_remote_code=True)

# Load the weights in bfloat16, move them to this process's device and switch
# the module to eval mode before any generation happens.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
model = model.to(DEVICE)
model.eval()

# The processor must exist before `run_kh_batch` is called.
processor = AutoProcessor.from_pretrained(model_id)
# --- Processor configuration ---
# Batched generation pads prompts on the left so that every prompt ends right
# where generation starts.  This is applied unconditionally: it used to be tied
# to the (unrelated) shape of `image_processor.size`.
processor.tokenizer.padding_side = 'left'

# The image-size override (size['longest_edge'] = 512 * 4) is currently
# disabled; enabling it can consume significant VRAM.  We still warn when the
# size config is not a dict, because the override could not be applied in that
# form.  `getattr` keeps the warning itself from raising when the image
# processor has no `size` attribute at all.
_image_processor_size = getattr(processor.image_processor, 'size', None)
if not isinstance(_image_processor_size, dict):
    print(
        f"Warning: Could not directly set 'longest_edge' via dict. Current image processor size config: {_image_processor_size}"
    )
def _load_rgb_image(image_or_path):
    """Return an RGB image for a file path or an already-loaded image object."""
    if isinstance(image_or_path, (str, os.PathLike)):
        # Close the file handle as soon as the pixels have been converted;
        # `convert` returns a new in-memory image.
        with Image.open(image_or_path) as opened_image:
            return opened_image.convert("RGB")
    # Anything else is assumed to already be a PIL Image object.
    return image_or_path.convert("RGB")


def run_kh_batch(batch_data_list):
    """Generate one answer string per item of a batch.

    Args:
        batch_data_list: sequence of dicts.  Each dict provides
            ``'image_path'`` (a path or an already-loaded image object) and
            ``'model_input_text'`` (the question text inserted into
            ``PROMPT_TEMPLATE``).

    Returns:
        A list of decoded generations, one per input item and in the same
        order, with surrounding '.' characters and whitespace stripped.
        An empty batch yields an empty list without touching the model.
    """
    if not batch_data_list:
        return []

    batch_images = []
    batch_prompts = []
    for item in batch_data_list:
        question_with_tags = PROMPT_TEMPLATE.format(question=item['model_input_text'].strip())
        batch_images.append(_load_rgb_image(item['image_path']))

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question_with_tags},
                ],
            },
        ]
        try:
            templated_prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        except Exception:
            # Deliberate best-effort fallback: if the processor cannot apply a
            # chat template, use a plain USER/ASSISTANT prompt instead.
            templated_prompt = f"USER: <image>\n{question_with_tags}\nASSISTANT:"
        batch_prompts.append(templated_prompt)

    encoded = processor(
        text=batch_prompts,
        images=batch_images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Move everything to the device; floating-point tensors are additionally
    # cast to bfloat16 to match the dtype the model was loaded with.
    inputs = {
        key: value.to(DEVICE).to(torch.bfloat16) if value.is_floating_point() else value.to(DEVICE)
        for key, value in encoded.items()
    }

    with unwrap_model_for_generation(model, accelerator) as unwrapped_model_instance:
        generated_ids = unwrapped_model_instance.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding
        )

    # `generate` returns prompt + continuation; keep only the continuation.
    prompt_length = inputs['input_ids'].shape[1]
    generated_texts = processor.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
    )
    return [text.strip('.').strip() for text in generated_texts]
# --- Main Evaluation Logic ---
# Evaluates the model on the ChartQA test split.  Every process scores one
# contiguous shard of the prepared samples.  At each reporting point the
# accumulated per-process scores are gathered and the main process prints the
# running global accuracy.
task = 'chart'

if task == 'chart':
    dt_record_local = {}  # Results accumulated by the current process

    if accelerator.is_main_process:
        print("Loading ChartQA dataset...")
    try:
        full_dataset = load_dataset("HuggingFaceM4/ChartQA", trust_remote_code=True)['test']
    except Exception as e:
        if accelerator.is_main_process:
            print(f"Failed to load dataset directly. Error: {e}")
            print("Attempting to load with specific revision if applicable, or check path/connection.")
        raise  # Nothing can be evaluated without the dataset
    # full_dataset = full_dataset.select(range(80))  # Uncomment for quick testing

    # Build the flat list of evaluation samples, skipping items without a label.
    eval_datasets_all_prepared = []
    for d_item in tqdm(full_dataset, desc="Preparing dataset", disable=not accelerator.is_main_process):
        raw_question = d_item['query']
        answer_list = d_item.get('label')
        if not answer_list:  # 'label' is missing or empty
            if accelerator.is_main_process:
                tqdm.write(
                    f"Warning: Item missing 'label' or 'label' is empty. Query: {raw_question[:50]}..."
                )
            continue
        eval_datasets_all_prepared.append({
            'image_path': d_item['image'],
            'model_input_text': raw_question,
            'answer': answer_list[0],
            'original_question': raw_question,
        })

    num_processes = accelerator.num_processes
    process_index = accelerator.process_index
    total_items = len(eval_datasets_all_prepared)

    if total_items == 0:
        if accelerator.is_main_process:
            print("No data prepared for evaluation after filtering. Exiting chart evaluation.")
            print("No data was prepared for evaluation. Nothing to report.")
    else:
        # Contiguous sharding: the first `extra_items` ranks take one extra sample.
        items_per_proc = total_items // num_processes
        extra_items = total_items % num_processes
        local_start_index = process_index * items_per_proc + min(process_index, extra_items)
        num_local_items = items_per_proc + (1 if process_index < extra_items else 0)
        local_end_index = local_start_index + num_local_items
        eval_datasets_local = eval_datasets_all_prepared[local_start_index:local_end_index]

        BATCH_SIZE = 32  # Adjust according to your VRAM
        REPORT_INTERVAL_BATCHES = 1  # Report once every N rounds (main process prints global stats)

        # Every rank must issue the same number of collective calls, otherwise
        # the job hangs.  The round count is therefore derived from the LARGEST
        # shard, which is identical on all ranks; ranks that run out of data
        # simply contribute no new scores in the remaining rounds.
        max_local_items = items_per_proc + (1 if extra_items else 0)
        num_sync_rounds = (max_local_items + BATCH_SIZE - 1) // BATCH_SIZE

        pbar = None
        if accelerator.is_main_process and eval_datasets_local:
            pbar = tqdm(total=len(eval_datasets_local), desc=f"Eval Proc {process_index}", dynamic_ncols=True)

        dt_record_local['res'] = []

        for batch_idx_local in range(num_sync_rounds):
            start_idx = batch_idx_local * BATCH_SIZE
            current_batch_list = eval_datasets_local[start_idx:start_idx + BATCH_SIZE]

            if current_batch_list:
                batch_predictions_texts = run_kh_batch(current_batch_list)
                for original_item, full_pred_text in zip(current_batch_list, batch_predictions_texts):
                    ground_truth_answer = original_item['answer']
                    _, parsed_pred_answer = split_initial_context(full_pred_text)
                    if not parsed_pred_answer.strip():
                        # Fall back to the full prediction if the parsed answer is empty
                        parsed_pred_answer = full_pred_text
                    score = eval_one_chart(parsed_pred_answer, ground_truth_answer)
                    dt_record_local['res'].append(score)
                    if accelerator.is_main_process:
                        print(full_pred_text, "######", ground_truth_answer, "######", score)
                if pbar:
                    pbar.update(len(current_batch_list))

            # --- Intermediate reporting logic ---
            is_last_local_batch = (batch_idx_local == num_sync_rounds - 1)
            should_sync_and_report = (
                (batch_idx_local + 1) % REPORT_INTERVAL_BATCHES == 0 or is_last_local_batch
            )
            if not should_sync_and_report:
                continue

            accelerator.wait_for_everyone()  # All ranks reach the sync point together
            if num_processes > 1:
                # Each process contributes its *currently accumulated* record.
                gathered_all_processes_data = [None] * num_processes
                all_gather_object(gathered_all_processes_data, dt_record_local)
            else:
                # Single-process run: no collective is needed (and a process
                # group may not have been initialised).
                gathered_all_processes_data = [dt_record_local]

            if accelerator.is_main_process:
                current_global_scores_list = []
                for process_data_dict in gathered_all_processes_data:
                    if process_data_dict and 'res' in process_data_dict:
                        current_global_scores_list.extend(process_data_dict['res'])
                total_samples_processed_globally = len(current_global_scores_list)

                report_title = "--- Intermediate Report ---"
                if is_last_local_batch and total_samples_processed_globally == total_items:
                    report_title = "--- Final Report ---"
                elif is_last_local_batch:
                    report_title = (
                        f"--- Report (Main Proc Last Batch, {batch_idx_local + 1}/{num_sync_rounds}) ---"
                    )
                tqdm.write(f"\n{report_title}")  # tqdm.write avoids clashing with the progress bar

                if current_global_scores_list:
                    mean_acc_global = np.array(current_global_scores_list).mean()
                    print(f"Global samples processed: {total_samples_processed_globally} / {total_items}")
                    print(f"Current Global Mean Accuracy: {mean_acc_global:.4f}")
                    if pbar:
                        pbar.set_description(
                            f"Global Acc: {mean_acc_global:.4f} ({total_samples_processed_globally}/{total_items})"
                        )
                else:
                    print(
                        f"No scores to report globally yet (Total processed: {total_samples_processed_globally})."
                    )
            accelerator.wait_for_everyone()  # Keep ranks aligned after the report

        if pbar:
            pbar.close()

        # Final metrics have already been printed by the last report above.
        if accelerator.is_main_process and not eval_datasets_local:
            print(
                "Main process had no data, but other processes might have. "
                "Final global metrics are printed by the last reporting sync."
            )
else:
    if accelerator.is_main_process:
        print(f"Task '{task}' is not configured for batched evaluation in this script.")
|