|
|
import os |
|
|
import re |
|
|
import json |
|
|
import argparse |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoModelForCausalLM, AutoProcessor, AutoModelForVision2Seq |
|
|
try: |
|
|
from transformers import AutoModelForImageTextToText |
|
|
except Exception: |
|
|
AutoModelForImageTextToText = None |
|
|
import importlib |
|
|
|
|
|
|
|
|
def extract_boxed_answer(text: str) -> str:
    """Extract the final answer from model (or reference) text.

    Priority:
    1) the content of the first ``<answer>...</answer>`` pair, tags matched
       case-insensitively (backward compatibility);
    2) the content of the last ``\\boxed{...}``, with nested braces matched.

    Returns ``""`` when ``text`` is empty, not a string, or contains neither
    form (an unterminated ``\\boxed{`` also yields ``""``).
    """
    if not text or not isinstance(text, str):
        return ""

    # Search the original string case-insensitively instead of indexing into
    # ``text.lower()``: lowercasing can change the length of some non-ASCII
    # strings (e.g. U+0130 lowercases to two code points), which would
    # misalign the slice indices.
    open_tag = re.search(r"<answer>", text, flags=re.IGNORECASE)
    if open_tag is not None:
        # Only a closing tag *after* the opening one delimits the answer.
        close_tag = re.compile(r"</answer>", re.IGNORECASE).search(text, open_tag.end())
        if close_tag is not None:
            return text[open_tag.end():close_tag.start()].strip()

    # Last occurrence of the literal "\boxed{" (it cannot overlap itself).
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return ""
    start += len(marker)

    depth = 1  # already inside the opening brace of \boxed{
    for pos in range(start, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:pos].strip()
    return ""
|
|
|
|
|
|
|
|
def normalize_answer(ans: str) -> str:
    """Canonicalize an answer for exact-match comparison.

    Simple normalization used during training env: the text is lowercased and
    every whitespace character is removed. ``None`` or empty input yields ``""``.
    """
    lowered = (ans or "").lower()
    # Removing all whitespace already subsumes stripping the ends.
    return re.sub(r"\s+", "", lowered)
|
|
|
|
|
|
|
|
def setup_distributed():
    """Return ``(rank, world_size, local_rank)``, initializing torch.distributed if needed.

    If a process group already exists, its rank/world size are reused and
    ``local_rank`` is read from ``LOCAL_RANK`` (default 0). Otherwise the values
    come from the ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` environment variables
    (falling back to a single process when ``RANK`` or ``WORLD_SIZE`` is
    missing). When ``world_size > 1`` an NCCL process group is created and the
    current CUDA device is set to ``local_rank``.
    """
    if dist.is_initialized():
        return dist.get_rank(), dist.get_world_size(), int(os.environ.get("LOCAL_RANK", 0))

    env = os.environ
    launched_by_launcher = "RANK" in env and "WORLD_SIZE" in env
    if launched_by_launcher:
        rank = int(env["RANK"])
        world_size = int(env["WORLD_SIZE"])
        local_rank = int(env.get("LOCAL_RANK", 0))
    else:
        # Plain `python script.py` run: behave as a single process.
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank
|
|
|
|
|
|
|
|
def shard_data(data: List[Dict[str, Any]], rank: int, world_size: int) -> List[Dict[str, Any]]:
    """Return the contiguous slice of ``data`` assigned to ``rank``.

    Every rank gets ``ceil(len(data) / world_size)`` consecutive items except
    the trailing ranks, which may get fewer (or none). Concatenating the shards
    of ranks ``0..world_size-1`` reproduces ``data`` in order.
    """
    chunk = (len(data) + world_size - 1) // world_size  # ceiling division
    first = rank * chunk
    # Slicing clips at len(data) and yields [] when ``first`` is past the end.
    return data[first:first + chunk]
|
|
|
|
|
|
|
|
def load_dataset(json_path: str) -> List[Dict[str, Any]]:
    """Load a UTF-8 JSON file whose top-level value must be an array.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    with open(json_path, encoding="utf-8") as handle:
        records = json.load(handle)
    if isinstance(records, list):
        return records
    raise ValueError(f"Expected a JSON array at {json_path}")
|
|
|
|
|
|
|
|
def open_image(image_path: Optional[str]) -> Optional[Image.Image]:
    """Load an image from disk converted to RGB, or return ``None``.

    ``None`` is returned for a ``None`` path, a missing/unreadable file, or any
    error raised while opening/decoding, so the caller decides how to handle a
    sample without a usable image.
    """
    if image_path is None:
        return None
    try:
        # The context manager closes the file handle Pillow keeps open after
        # ``Image.open`` (lazy loading; multi-frame formats keep it open even
        # after loading). ``convert`` returns a new, fully-loaded image that
        # remains valid after the file is closed.
        with Image.open(image_path) as img:
            return img.convert("RGB")
    except Exception:
        # Deliberate best-effort (a missing file raises FileNotFoundError
        # here): any failure means "no image for this sample".
        return None
|
|
|
|
|
|
|
|
def _load_qwen_vl_model(model_id: str, torch_dtype, device_map: str, local_rank: int = 0):
    """Instantiate a Qwen2.5-VL model, trying several transformers entry points.

    Candidates, in order:
      1. ``Qwen2_5_VLForConditionalGeneration`` from
         ``transformers.models.qwen2_5_vl.modeling_qwen2_5_vl``;
      2. ``AutoModelForImageTextToText`` (skipped when this transformers build
         does not provide it);
      3. ``AutoModelForVision2Seq``.
    Each failed candidate is reported with a ``[DEBUG]`` line and the next one
    is tried.

    NOTE: ``device_map="auto"`` is rewritten to ``cuda:{local_rank}`` (and
    ``local_rank`` defaults to 0), so each process places the whole model on a
    single GPU instead of letting accelerate shard it. Any other value is passed
    through unchanged.

    Raises:
        RuntimeError: if every candidate fails.
    """
    if device_map == "auto" and local_rank >= 0:
        resolved_device_map = f"cuda:{local_rank}"
    else:
        resolved_device_map = device_map

    def _load_with(model_cls):
        # Newer transformers take ``dtype``; older ones only know ``torch_dtype``.
        shared_kwargs = {"device_map": resolved_device_map, "trust_remote_code": True}
        try:
            return model_cls.from_pretrained(model_id, dtype=torch_dtype, **shared_kwargs)
        except TypeError:
            return model_cls.from_pretrained(model_id, torch_dtype=torch_dtype, **shared_kwargs)

    def _resolve_specific_cls():
        module = importlib.import_module("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl")
        return getattr(module, "Qwen2_5_VLForConditionalGeneration", None)

    candidates = (
        ("Qwen2_5_VLForConditionalGeneration", _resolve_specific_cls),
        ("AutoModelForImageTextToText", lambda: AutoModelForImageTextToText),
        ("AutoModelForVision2Seq", lambda: AutoModelForVision2Seq),
    )
    for label, resolve in candidates:
        try:
            model_cls = resolve()
            if model_cls is None:
                # Class not available in this transformers version: move on quietly.
                continue
            return _load_with(model_cls)
        except Exception as e:
            print(f"[DEBUG] Failed to load with {label}: {e}")

    raise RuntimeError(f"Could not load Qwen2.5-VL model from {model_id}. All loading methods failed.")
|
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_answer(
    model,
    processor,
    prompt: str,
    image: Optional[Image.Image],
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Generate a single chat-formatted response and return the decoded continuation.

    Builds one user turn (optional image, then the text prompt), applies the
    processor's chat template with a generation prompt, runs ``model.generate``
    and decodes only the newly generated tokens (special tokens skipped).

    Decoding: sampling is used only when ``do_sample`` is true AND
    ``temperature > 0``; otherwise greedy decoding is used and ``temperature`` /
    ``top_p`` are not forwarded. A non-positive temperature conventionally
    means greedy. transformers rejects it when sampling is enabled, and it only
    triggers warnings in greedy mode.
    """
    content: List[Dict[str, Any]] = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[chat_text], images=[image] if image is not None else None, return_tensors="pt")
    # Move tensors to the model's device; leave any non-tensor entries as-is.
    inputs = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens, "use_cache": True}
    if do_sample and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    gen_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    text_out = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
    return text_out.strip()
|
|
|
|
|
|
|
|
def _parse_eval_args() -> argparse.Namespace:
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate Qwen2.5-VL baseline on MM_Math and compute accuracy")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="HF model id/path")
    parser.add_argument("--data", type=str, default="/root/CVPR/MemGen/data/mm_math/train.json", help="Path to preprocessed split JSON")
    parser.add_argument("--output_jsonl", type=str, default="/root/CVPR/MemGen/test_output/mm_math/logs/qwen25vl_eval.jsonl", help="Where to save per-sample logs")
    parser.add_argument("--max_samples", type=int, default=-1, help="Limit number of evaluated samples; -1 for all")
    parser.add_argument("--device_map", type=str, default="auto", help="transformers device_map")
    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"], help="Model dtype")
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--skip_missing_image", action="store_true", help="Skip samples if image not found; otherwise evaluate with text-only")
    parser.add_argument("--append", action="store_true", help="Append to output JSONL and stream-save each sample")
    parser.add_argument("--no_fsync", action="store_true", help="Do not call os.fsync after each write (faster, less durable)")
    return parser.parse_args()


def _rank_temp_dir(output_jsonl: str) -> str:
    """Directory that holds the per-rank partial outputs next to ``output_jsonl``."""
    return os.path.join(os.path.dirname(output_jsonl), ".tmp_ranks")


def _prepare_rank_output(output_jsonl: str, rank: int, world_size: int) -> str:
    """Create output directories (rank 0) and return the JSONL path this rank writes to.

    Single-process runs write directly to ``output_jsonl``. Distributed runs
    write to ``<temp dir>/rank{rank}_<basename>``. Contains a barrier when
    ``world_size > 1``, so every rank must call it.
    """
    out_dir = os.path.dirname(output_jsonl)
    # A bare filename has dirname "" (current directory); os.makedirs("") raises.
    if rank == 0 and out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if world_size <= 1:
        return output_jsonl

    temp_dir = _rank_temp_dir(output_jsonl)
    if rank == 0:
        os.makedirs(temp_dir, exist_ok=True)
    # Other ranks must not open files in temp_dir before rank 0 has created it.
    dist.barrier()
    return os.path.join(temp_dir, f"rank{rank}_{os.path.basename(output_jsonl)}")


def _evaluate_samples(model, processor, indexed_data: List[Any], args: argparse.Namespace, fout):
    """Generate, score and log every ``(global_idx, example)`` pair; return ``(num_correct, num_total)``.

    Each evaluated sample is written to ``fout`` as one JSON line and flushed
    (and fsync'ed unless ``--no_fsync``). Samples skipped because their image is
    missing under ``--skip_missing_image`` are neither logged nor counted.
    """
    num_correct = 0
    num_total = 0
    for global_idx, ex in tqdm(indexed_data, desc="Evaluating"):
        prompt: str = ex.get("prompt", "") or ""
        gt_boxed: str = ex.get("solution", "") or ""
        image_path: Optional[str] = ex.get("image_path", None)

        image = open_image(image_path)
        if image is None and image_path and args.skip_missing_image:
            continue

        try:
            pred_text = generate_answer(
                model=model,
                processor=processor,
                prompt=prompt,
                image=image,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=args.do_sample,
            )
        except Exception as e:
            # Best-effort: record the failure for this sample instead of aborting the run.
            pred_text = f"[GENERATION_FAILED] {e}"

        pred_ans = extract_boxed_answer(pred_text)
        gt_ans = extract_boxed_answer(gt_boxed) if gt_boxed else ""

        correct = False
        if pred_ans and gt_ans:
            correct = normalize_answer(pred_ans) == normalize_answer(gt_ans)

        num_total += 1
        if correct:
            num_correct += 1

        log_item = {
            "correct": bool(correct),
            "prediction_extracted": pred_ans,
            "ground_truth_extracted": gt_ans,
            "prediction_text": pred_text,
            "ground_truth": gt_boxed,
            "id": global_idx,
            "prompt": prompt,
            "image_path": image_path,
        }
        fout.write(json.dumps(log_item, ensure_ascii=False) + "\n")
        fout.flush()
        if not args.no_fsync:
            try:
                os.fsync(fout.fileno())
            except OSError:
                # Durability is best-effort (e.g. filesystems without fsync support).
                pass
    return num_correct, num_total


def _all_reduce_counts(num_correct: int, num_total: int, world_size: int, local_rank: int):
    """Sum ``(num_correct, num_total)`` over all ranks (identity when single-process)."""
    if world_size <= 1:
        return num_correct, num_total
    counts = torch.tensor([num_correct, num_total], dtype=torch.long, device=f"cuda:{local_rank}")
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    return int(counts[0].item()), int(counts[1].item())


def _merge_distributed_outputs(output_jsonl: str, world_size: int, append: bool) -> None:
    """Merge the per-rank JSONL files into ``output_jsonl`` (call on rank 0 only).

    ``merge_rank_outputs`` opens its target in "w" mode. With ``append``, the
    shards are therefore merged into a side file whose contents are appended to
    ``output_jsonl``, so records from earlier runs are kept.
    """
    import shutil

    temp_dir = _rank_temp_dir(output_jsonl)
    base = os.path.basename(output_jsonl)
    if not append:
        merge_rank_outputs(output_jsonl, temp_dir, base, world_size)
        return

    staging = output_jsonl + ".merging"
    merge_rank_outputs(staging, temp_dir, base, world_size)
    with open(staging, "r", encoding="utf-8") as src, open(output_jsonl, "a", encoding="utf-8") as dst:
        shutil.copyfileobj(src, dst)
    os.remove(staging)


def main():
    """Evaluate a Qwen2.5-VL model on an MM_Math split and print accuracy.

    Works single-process or under a launcher that sets ``RANK``/``WORLD_SIZE``
    (e.g. torchrun). In distributed mode the data is split into contiguous
    shards, each rank logs to its own temp JSONL, counts are summed with
    ``all_reduce``, and rank 0 merges the logs into ``--output_jsonl`` and
    prints the final metrics.
    """
    args = _parse_eval_args()
    rank, world_size, local_rank = setup_distributed()
    temp_output_jsonl = _prepare_rank_output(args.output_jsonl, rank, world_size)

    data = load_dataset(args.data)
    if args.max_samples is not None and args.max_samples > 0:
        data = data[: args.max_samples]

    # Pair each sample with its index in the (possibly truncated) dataset BEFORE
    # sharding, so the logged "id" stays globally unique after ranks are merged.
    # shard_data only slices, so it works on this list of pairs.
    indexed_data = list(enumerate(data))
    if world_size > 1:
        indexed_data = shard_data(indexed_data, rank, world_size)
        if rank == 0:
            print(f"[Distributed] Total GPUs: {world_size}, Rank {rank} processing {len(indexed_data)} samples")
        dist.barrier()

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map.get(args.dtype, torch.bfloat16)

    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = _load_qwen_vl_model(args.model, torch_dtype=torch_dtype, device_map=args.device_map, local_rank=local_rank)
    model.eval()
    if rank == 0:
        print(f"Model loaded on {world_size} GPU(s)")

    # NOTE: in distributed mode with --append, leftover rank files from an
    # interrupted previous run (if any) are appended to and merged as well.
    file_mode = "a" if args.append else "w"
    with open(temp_output_jsonl, file_mode, encoding="utf-8") as fout:
        num_correct, num_total = _evaluate_samples(model, processor, indexed_data, args, fout)

    global_correct, global_total = _all_reduce_counts(num_correct, num_total, world_size, local_rank)

    if world_size > 1:
        # All ranks have closed their files before rank 0 starts merging.
        dist.barrier()

    if rank == 0:
        if world_size > 1:
            print(f"\nMerging results from {world_size} ranks into {args.output_jsonl}...")
            _merge_distributed_outputs(args.output_jsonl, world_size, args.append)

        acc = (global_correct / global_total) if global_total > 0 else 0.0
        print("\n" + "=" * 50)
        print("Final Results:")
        print("=" * 50)
        print(json.dumps({
            "accuracy": acc,
            "num_correct": global_correct,
            "num_total": global_total,
            "data_path": args.data,
            "model": args.model,
            "output_jsonl": args.output_jsonl,
            "world_size": world_size,
        }, ensure_ascii=False, indent=2))
        print("=" * 50)

    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
def merge_rank_outputs(output_path: str, temp_dir: str, base_filename: str, world_size: int):
    """Concatenate every rank's JSONL file, in rank order, into ``output_path``.

    Reads ``<temp_dir>/rank{r}_<base_filename>`` for ``r`` in
    ``range(world_size)``. Missing files are reported and skipped; blank lines
    are ignored, and every other line is parsed as JSON and re-serialized.
    ``output_path`` is opened in "w" mode (overwritten). Afterwards ``temp_dir``
    is removed; a cleanup failure is reported but not raised.
    """
    import shutil

    records = []
    for shard_rank in range(world_size):
        shard_path = os.path.join(temp_dir, f"rank{shard_rank}_{base_filename}")
        if not os.path.exists(shard_path):
            print(f"Warning: {shard_path} not found")
            continue
        with open(shard_path, "r", encoding="utf-8") as shard:
            records.extend(json.loads(row) for row in shard if row.strip())

    with open(output_path, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

    print(f"✓ Merged {len(records)} results into {output_path}")

    # Best-effort cleanup: the merged output is already written at this point.
    try:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"✓ Cleaned up temporary files in {temp_dir}")
    except Exception as e:
        print(f"Warning: Failed to cleanup temp directory {temp_dir}: {e}")
|
|
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|