agentic-rl-main / eval /eval_chartqa.py
Jack04810's picture
Add files using upload-large-folder tool
cc0721b verified
Raw
History Blame Contribute Delete
14.7 kB
import argparse
import os
import torch
from PIL import Image
from accelerate import Accelerator
# Ensure this path is correct and the utility is available.
from datasets import load_dataset
from torch.distributed import all_gather_object
from transformers import AutoProcessor, AutoConfig, AutoTokenizer, LlavaOnevisionForConditionalGeneration
from trl.models import unwrap_model_for_generation
from data_utils.chart.evaluator import eval_one_chart
from data_utils.rl_prompt import PROMPT_TEMPLATE
from reward_utils.compute_rewards import split_initial_context
accelerator = Accelerator()
from tqdm import tqdm
import numpy as np
DEVICE = accelerator.device
# Model and Processor Configuration
model_args = {} # Use {"torch_dtype": torch.bfloat16} if desired and supported
_eval_parser = argparse.ArgumentParser(add_help=False)
_eval_parser.add_argument("--model_path", default=None)
_eval_args, _ = _eval_parser.parse_known_args()
model_id = (
_eval_args.model_path
or os.environ.get("CHECKPOINT_DIR")
or "/path/to/dyme-k-8/final_checkpoint"
)
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, config=config, trust_remote_code=True)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
).to(DEVICE)
model.eval()
# Make sure model and processor are loaded before being potentially used in generate_inner if it were called
# model = Idefics3ForConditionalGeneration.from_pretrained(model_id, **model_args).to(DEVICE)
processor = AutoProcessor.from_pretrained(model_id)
# Configure image processor size
# This can consume significant VRAM. Ensure it's intended.
if hasattr(processor.image_processor, 'size') and isinstance(processor.image_processor.size, dict):
# if 'longest_edge' in processor.image_processor.size:
# print('Setting image processor longest_edge to 2048')
# processor.image_processor.size['longest_edge'] = 512 * 4
processor.tokenizer.padding_side = 'left'
else:
print(
f"Warning: Could not directly set 'longest_edge' via dict. Current image processor size config: {processor.image_processor.size}"
)
# Attempt an alternative if applicable, e.g.
# processor.image_processor.size = {"longest_edge": 512 * 4} # if size itself can be replaced
# Or this might indicate that `size` is a single value or a different structure.
def run_kh_batch(batch_data_list): # Renamed from run_kh, takes a batch
batch_images = []
batch_formatted_prompts_for_chat_template = []
for item in batch_data_list:
image_path = item['image_path']
# 'item_model_input_text' already contains chart instructions + raw_question
item_model_input_text = item['model_input_text'].strip()
# question_with_tags = prompt + item_model_input_text
# question_with_tags = f"""{item_model_input_text} Think step by step and then answer the question."""
question_with_tags = PROMPT_TEMPLATE.format(question=item_model_input_text)
if isinstance(image_path, str):
image = Image.open(image_path).convert("RGB")
else:
image = image_path.convert("RGB") # Assuming image_path is already a PIL Image object
batch_images.append(image)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": question_with_tags},
]
},
]
try:
templated_prompt_str = processor.apply_chat_template(messages, add_generation_prompt=True)
templated_prompt_str = templated_prompt_str.strip()
except Exception:
templated_prompt_str = f"USER: <image>\n{question_with_tags}\nASSISTANT:"
batch_formatted_prompts_for_chat_template.append(templated_prompt_str)
inputs = processor(
text=batch_formatted_prompts_for_chat_template,
images=batch_images,
return_tensors="pt",
padding=True,
truncation=True
)
# inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
inputs = {
k: v.to(DEVICE).to(torch.bfloat16) if v.is_floating_point() else v.to(DEVICE)
for k, v in inputs.items()
}
with unwrap_model_for_generation(model, accelerator) as unwrapped_model_instance:
generated_ids = unwrapped_model_instance.generate(
**inputs,
max_new_tokens=1024,
do_sample=False,
)
input_ids_length = inputs['input_ids'].shape[1]
newly_generated_ids = generated_ids[:, input_ids_length:]
generated_texts = processor.batch_decode(
newly_generated_ids,
skip_special_tokens=True, # Special tokens like <eos> are removed. <image> might be too.
)
return [text.strip('.').strip() for text in generated_texts]
# --- Main Evaluation Logic ---
task = 'chart'
# dt_record_local is initialized inside the if task == 'chart' block
if task == 'chart':
dt_record_local = {} # Store results for the current process
if accelerator.is_main_process:
print("Loading ChartQA dataset...")
try:
full_dataset = load_dataset("HuggingFaceM4/ChartQA", trust_remote_code=True)['test']
except Exception as e:
if accelerator.is_main_process:
print(f"Failed to load dataset directly. Error: {e}")
print("Attempting to load with specific revision if applicable, or check path/connection.")
# For example, you can try a specific revision (if known) or ensure path and network connection are correct
# full_dataset = load_dataset("HuggingFaceM4/ChartQA", revision="main", trust_remote_code=True)['test']
raise # Re-raise the exception since we cannot proceed without the dataset
# full_dataset = full_dataset.select(range(80)) # Uncomment for quick testing
eval_datasets_all_prepared = []
# chart_instructions_prefix = (
# "For the question below, follow the following instructions:\n"
# # ... (your detailed instructions) ...
# + "-Try to include the full label from the graph when asked about an entity.\n"
# + "Question: "
# )
for d_item in tqdm(full_dataset, desc="Preparing dataset", disable=not accelerator.is_main_process):
image_path = d_item['image']
raw_question = d_item['query']
answer_list = d_item.get('label') # Use .get() in case 'label' field does not exist
if not answer_list: # If 'label' is missing or an empty list
if accelerator.is_main_process:
tqdm.write(
f"Warning: Item missing 'label' or 'label' is empty. Query: {raw_question[:50]}..."
)
# Decide how to handle this: skip this sample or use a default answer
continue # Skip this sample
answer = answer_list[0]
model_input_text_for_template = raw_question
eval_datasets_all_prepared.append({
'image_path': image_path,
'model_input_text': model_input_text_for_template,
'answer': answer,
'original_question': raw_question
})
num_processes = accelerator.num_processes
process_index = accelerator.process_index
total_items = len(eval_datasets_all_prepared)
if total_items == 0:
if accelerator.is_main_process:
print("No data prepared for evaluation after filtering. Exiting chart evaluation.")
else:
items_per_proc = total_items // num_processes
extra_items = total_items % num_processes
local_start_index = process_index * items_per_proc + min(process_index, extra_items)
num_local_items = items_per_proc + (1 if process_index < extra_items else 0)
local_end_index = local_start_index + num_local_items
eval_datasets_local = eval_datasets_all_prepared[local_start_index:local_end_index]
BATCH_SIZE = 32 # Adjust according to your VRAM
REPORT_INTERVAL_BATCHES = 1 # Report once every N local batches (main process prints global stats)
# if accelerator.is_main_process:
# print(f"Total items for evaluation: {total_items}")
# print(f"Process {process_index} handling {len(eval_datasets_local)} items.")
# print(f"Batch size per process: {BATCH_SIZE}, Reporting interval: {REPORT_INTERVAL_BATCHES} local batches.")
pbar = None
if accelerator.is_main_process and len(eval_datasets_local) > 0: # Create pbar only if there is data
pbar = tqdm(total=len(eval_datasets_local), desc=f"Eval Proc {process_index}", dynamic_ncols=True)
dt_record_local['res'] = []
num_local_batches = (len(eval_datasets_local) + BATCH_SIZE - 1) // BATCH_SIZE
for batch_idx_local in range(num_local_batches):
start_idx = batch_idx_local * BATCH_SIZE
end_idx = min((batch_idx_local + 1) * BATCH_SIZE, len(eval_datasets_local))
current_batch_list = eval_datasets_local[start_idx:end_idx]
if not current_batch_list:
continue
batch_predictions_texts = run_kh_batch(current_batch_list)
for item_idx_in_batch, full_pred_text in enumerate(batch_predictions_texts):
original_item = current_batch_list[item_idx_in_batch]
ground_truth_answer = original_item['answer']
_, parsed_pred_answer = split_initial_context(full_pred_text)
if not parsed_pred_answer.strip():
parsed_pred_answer = full_pred_text # Fallback to full prediction if parsed answer is empty
score = eval_one_chart(parsed_pred_answer, ground_truth_answer) # nlp object is global
dt_record_local['res'].append(score)
# (Optional) Main process prints a few prediction details
if accelerator.is_main_process:
print(full_pred_text, "######", ground_truth_answer, "######", score)
if pbar:
pbar.update(len(current_batch_list))
# --- Intermediate reporting logic ---
is_last_local_batch = (batch_idx_local == num_local_batches - 1)
# Every REPORT_INTERVAL_BATCHES local batches, or on the last local batch of this process,
# perform synchronization and reporting
should_sync_and_report = ((batch_idx_local + 1) % REPORT_INTERVAL_BATCHES == 0) or is_last_local_batch
# Make sure that even if REPORT_INTERVAL_BATCHES is 1, we do not report when there is no data
# (e.g., len(eval_datasets_local) == 0)
if len(eval_datasets_local) == 0: # If the current process has no data, skip reporting logic
should_sync_and_report = False
# If num_local_batches > 0, this check ensures we report only when there is data
if num_local_batches == 0 and is_last_local_batch: # Special case: process has no data but must join final sync
should_sync_and_report = True
if should_sync_and_report:
accelerator.wait_for_everyone() # Wait for all processes to reach the sync point
gathered_all_processes_data = [None] * num_processes
# Each process sends its *current accumulated* dt_record_local
# If a process has no data, dt_record_local['res'] is an empty list, which is fine
all_gather_object(gathered_all_processes_data, dt_record_local)
if accelerator.is_main_process:
current_global_scores_list = []
for process_data_dict in gathered_all_processes_data:
if process_data_dict and 'res' in process_data_dict:
current_global_scores_list.extend(process_data_dict['res'])
total_samples_processed_globally = len(current_global_scores_list)
report_title = "--- Intermediate Report ---"
# Check whether this is the final reporting point where all processes have finished
# A simple heuristic: if this is the last local batch on the main process
# and the total collected samples equal the total number of items
if is_last_local_batch and total_samples_processed_globally == total_items:
report_title = "--- Final Report ---"
elif is_last_local_batch: # Last batch on main process but perhaps not all samples are done yet
report_title = (
f"--- Report (Main Proc Last Batch, {batch_idx_local + 1}/{num_local_batches}) ---"
)
tqdm.write(f"\n{report_title}") # Use tqdm.write to avoid clashing with the progress bar
if current_global_scores_list:
mean_acc_global = np.array(current_global_scores_list).mean()
if accelerator.is_main_process:
print(f"Global samples processed: {total_samples_processed_globally} / {total_items}")
print(f"Current Global Mean Accuracy: {mean_acc_global:.4f}")
if pbar:
pbar.set_description(
f"Global Acc: {mean_acc_global:.4f} ({total_samples_processed_globally}/{total_items})"
)
else:
if accelerator.is_main_process:
print(
f"No scores to report globally yet (Total processed: {total_samples_processed_globally})."
)
accelerator.wait_for_everyone() # Sync again after reporting in case some processes move ahead faster
if pbar:
pbar.close()
# Final metrics have already been printed in the last report
# (when is_last_local_batch is True)
if accelerator.is_main_process and len(eval_datasets_local) == 0 and total_items > 0:
print(
"Main process had no data, but other processes might have. "
"Final global metrics are printed by the last reporting sync."
)
elif accelerator.is_main_process and total_items == 0:
print("No data was prepared for evaluation. Nothing to report.")
else:
if accelerator.is_main_process:
print(f"Task '{task}' is not configured for batched evaluation in this script.")