# eval-pack / eval/eval_vora.py
# Uploaded by jun-1001 via huggingface_hub (commit 2e7f2ce).
"""
VoRA Evaluation Script
- Perplexity (cross-entropy loss) on held-out caption data
- Caption generation with BLEU / ROUGE-L metrics
Usage:
# Perplexity evaluation
python eval/eval_vora.py --mode perplexity \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
# Caption generation evaluation
python eval/eval_vora.py --mode caption \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
# Both
python eval/eval_vora.py --mode all \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
"""
import argparse
import json
import math
import os
import sys
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoTokenizer
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.modeling_vora import VoRAForCausalLM, VoRAConfig
# ============================================================
# Image preprocessing (same as training pipeline)
# ============================================================
def expand2square(pil_img):
    """Pad an image with black borders so that it becomes square.

    The image is centred on a black canvas whose side equals its longer edge
    (same padding as the training pipeline). An image that is already square
    is returned as-is (the very same object, not a copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), (0, 0, 0))
    # Only the shorter dimension gets a non-zero offset, which centres it.
    canvas.paste(pil_img, ((side - w) // 2, (side - h) // 2))
    return canvas
def load_and_process_image(image_path, image_processor):
    """Load an image file and turn it into model-ready pixel values.

    Steps: decode with PIL, convert to RGB, pad to a square with
    ``expand2square``, then run the Hugging Face ``image_processor``.

    Args:
        image_path: Path of the image file on disk.
        image_processor: Callable HF image processor (e.g. the result of
            ``AutoImageProcessor.from_pretrained``).

    Returns:
        The ``"pixel_values"`` tensor produced by the processor. For a single
        image with ``return_tensors="pt"`` this is presumably (1, 3, H, W).
        The original note says (1, 3, 448, 448) for the AIMv2-448 processor;
        confirm this for other processors.
    """
    # Use the context manager so the underlying file handle is closed right
    # away. ``convert`` loads the pixels and returns an independent image, so
    # it stays valid after the file is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    img = expand2square(img)
    return image_processor(img, return_tensors="pt")["pixel_values"]
# ============================================================
# Text processing (same prompt template as training)
# ============================================================
# Sentinel id spliced into input_ids where the image goes. It is also passed
# to the model as ``vision_placeholder_index``. It is negative, so it can
# never collide with a real vocabulary id.
IMAGE_TOKEN_INDEX: int = -200
# Label value used for positions that must not be scored (the prompt tokens
# in build_perplexity_batch). -100 is PyTorch's default ``ignore_index`` for
# cross-entropy. Presumably the model's loss relies on that; confirm.
IGNORE_INDEX: int = -100
def build_prompt_ids(tokenizer, has_image=True):
    """Return the token ids of the chat prompt that precedes the caption.

    The prompt uses the ChatML-style template: a fixed system turn, a user
    turn, then the opening of the assistant turn.

    Args:
        tokenizer: Object exposing ``encode(str) -> list[int]``.
        has_image: If True, the user turn holds only the image slot. That slot
            is the ``IMAGE_TOKEN_INDEX`` sentinel, spliced between two
            separately encoded text halves. If False, the user turn holds the
            text instruction "Describe this image." instead.

    Returns:
        A plain list of ints.
    """
    header = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
    footer = "<|im_end|>\n<|im_start|>assistant\n"
    if not has_image:
        return tokenizer.encode(header + "Describe this image." + footer)
    # The image sentinel is not text, so each half is encoded on its own
    # and the sentinel is inserted between them.
    return [*tokenizer.encode(header), IMAGE_TOKEN_INDEX, *tokenizer.encode(footer)]
def build_perplexity_batch(tokenizer, image_path, caption, image_processor, device):
    """Assemble a single-example batch with labels for scoring ``caption``.

    The token sequence is ``prompt + caption + <|im_end|>``. In the labels,
    the prompt positions are replaced by ``IGNORE_INDEX``, so only the caption
    tokens and the closing ``<|im_end|>`` are targets.

    Returns:
        ``(batch, n_targets)``: the keyword arguments for the model's forward
        call, and the number of unmasked label positions (caption tokens
        plus one for ``<|im_end|>``).
    """
    prompt = build_prompt_ids(tokenizer, has_image=True)
    caption_ids = tokenizer.encode(caption)
    end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    target = caption_ids + [end_id]

    sequence = prompt + target
    masked_labels = [IGNORE_INDEX] * len(prompt) + target
    frames = load_and_process_image(image_path, image_processor)

    seq_len = len(sequence)
    batch = {
        "input_ids": torch.tensor([sequence], dtype=torch.long).to(device),
        "attention_mask": torch.ones(1, seq_len, dtype=torch.long).to(device),
        "labels": torch.tensor([masked_labels], dtype=torch.long).to(device),
        "frames": frames.to(device),
        "n_frames": [1],
        "vision_placeholder_index": IMAGE_TOKEN_INDEX,
    }
    return batch, len(target)
def build_generation_batch(tokenizer, image_path, image_processor, device):
    """Assemble a single-example, prompt-only batch for caption generation.

    It has the same layout as ``build_perplexity_batch`` but carries no
    labels and no caption tokens: the sequence ends right after the
    assistant-turn header.
    """
    ids = build_prompt_ids(tokenizer, has_image=True)
    frames = load_and_process_image(image_path, image_processor)
    n_tokens = len(ids)
    return {
        "input_ids": torch.tensor([ids], dtype=torch.long).to(device),
        "attention_mask": torch.ones(1, n_tokens, dtype=torch.long).to(device),
        "frames": frames.to(device),
        "n_frames": [1],
        "vision_placeholder_index": IMAGE_TOKEN_INDEX,
    }
# ============================================================
# Load evaluation data
# ============================================================
def load_eval_data(eval_path, max_samples=None):
    """Read evaluation samples from a JSON-lines file.

    Each non-blank line must be a JSON object. The evaluators in this script
    read its ``"image"`` (image file path) and ``"text"`` (reference caption)
    keys.

    Args:
        eval_path: Path to the ``.jsonl`` file, read as UTF-8.
        max_samples: Maximum number of samples to keep. ``None`` loads the
            whole file; ``0`` (or a negative value) yields no samples.

    Returns:
        List of the decoded objects, in file order.
    """
    data = []
    with open(eval_path, "r", encoding="utf-8") as f:
        for line in f:
            if max_samples is not None and len(data) >= max_samples:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. stray empty lines at the end).
                continue
            data.append(json.loads(line))
    print(f"Loaded {len(data)} evaluation samples")
    return data
# ============================================================
# Evaluation: Perplexity
# ============================================================
@torch.no_grad()
def evaluate_perplexity(model, tokenizer, image_processor, eval_data, device):
    """Compute corpus-level perplexity of the reference captions.

    Each sample gets one forward pass with labels on
    ``prompt + caption + <|im_end|>`` (built by ``build_perplexity_batch``).
    The per-sample loss is weighted by that sample's number of target tokens,
    so the result is a token-level average over the whole set. This assumes
    ``outputs.loss`` is the mean over non-ignored label positions; confirm
    against the model's forward.

    Samples whose image file is missing, or whose batch building or forward
    pass raises, are skipped and counted as errors.

    Returns:
        ``exp(mean token loss)``. Returns ``float("inf")`` when no sample
        could be scored, or when the mean loss is too large for ``math.exp``.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    n_missing = 0  # image file not found on disk
    n_failed = 0   # exception while building the batch or running the model
    for i, item in enumerate(tqdm(eval_data, desc="Perplexity")):
        image_path = item["image"]
        caption = item["text"]
        if not os.path.exists(image_path):
            n_missing += 1
            continue
        try:
            batch, n_caption_tokens = build_perplexity_batch(
                tokenizer, image_path, caption, image_processor, device)
            sample_loss = model(**batch).loss.item()
        except Exception as e:  # best-effort: one bad sample must not abort the run
            n_failed += 1
            # Counted separately from missing files, so real failures are
            # still printed even when many images are missing.
            if n_failed <= 5:
                print(f" Error on sample {i}: {e}")
            continue
        total_loss += sample_loss * n_caption_tokens
        total_tokens += n_caption_tokens

    errors = n_missing + n_failed
    if total_tokens == 0:
        print("No valid samples for perplexity!")
        return float("inf")
    avg_loss = total_loss / total_tokens
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        # math.exp overflows above ~709.78. Report inf instead of crashing
        # after the whole evaluation loop has run.
        perplexity = float("inf")
    print("\n=== Perplexity Results ===")
    print(f"Samples evaluated: {len(eval_data) - errors}/{len(eval_data)}")
    print(f"Errors: {errors} (missing images: {n_missing}, failed: {n_failed})")
    print(f"Average cross-entropy loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    return perplexity
# ============================================================
# Evaluation: Caption Generation
# ============================================================
@torch.no_grad()
def evaluate_caption(model, tokenizer, image_processor, eval_data, device,
                     max_new_tokens=256):
    """Generate captions greedily and score them with BLEU / ROUGE-L.

    Decoding is greedy (``do_sample=False``) and stops at ``<|im_end|>`` or
    after ``max_new_tokens``. Samples with a missing image, or whose
    generation raises, are skipped and counted as errors.

    Returns:
        The metrics dict from ``compute_caption_metrics``, or ``{}`` when no
        sample produced a caption.
    """
    model.eval()
    predictions = []
    references = []
    n_missing = 0  # image file not found on disk
    n_failed = 0   # exception while building the batch / generating / decoding
    eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    for i, item in enumerate(tqdm(eval_data, desc="Caption Generation")):
        image_path = item["image"]
        caption = item["text"]
        if not os.path.exists(image_path):
            n_missing += 1
            continue
        try:
            batch = build_generation_batch(tokenizer, image_path, image_processor, device)
            outputs = model.generate(
                batch,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=eos_token_id,
            )
            # NOTE(review): this assumes outputs[0] holds only the newly
            # generated tokens (no echoed prompt). Confirm against VoRA's
            # generate(); an echoed prompt would skew BLEU/ROUGE.
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:  # best-effort: one bad sample must not abort the run
            n_failed += 1
            # Counted separately from missing files, so real failures are
            # still printed even when many images are missing.
            if n_failed <= 5:
                print(f" Error on sample {i}: {e}")
            continue
        predictions.append(generated_text)
        references.append(caption)

    errors = n_missing + n_failed
    if len(predictions) == 0:
        print("No valid samples for caption evaluation!")
        return {}
    # Compute metrics
    metrics = compute_caption_metrics(predictions, references)
    print("\n=== Caption Generation Results ===")
    print(f"Samples evaluated: {len(predictions)}/{len(eval_data)}")
    print(f"Errors: {errors} (missing images: {n_missing}, failed: {n_failed})")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    # Print a few examples
    print("\n--- Sample Outputs (first 5) ---")
    for j, (pred, ref) in enumerate(zip(predictions[:5], references[:5])):
        print(f"[{j}] Generated: {pred[:200]}")
        print(f"[{j}] Reference: {ref[:200]}")
        print()
    return metrics
def compute_caption_metrics(predictions, references):
    """Compute corpus BLEU-1 / BLEU-4 (nltk) and mean ROUGE-L F1 (rouge_score).

    Args:
        predictions: Generated captions, one string per sample.
        references: Reference captions, aligned index-by-index with
            ``predictions`` (one reference per sample).

    Returns:
        Dict that may contain "BLEU-1", "BLEU-4" and "ROUGE-L". A metric is
        left out (with a printed warning) when its package is not installed.
        An empty dict is returned when there are no samples.

    Raises:
        ValueError: If ``predictions`` and ``references`` differ in length.
    """
    if len(predictions) != len(references):
        # zip() below would silently truncate; fail loudly instead.
        raise ValueError(
            f"predictions ({len(predictions)}) and references "
            f"({len(references)}) must have the same length")
    metrics = {}
    if not predictions:
        # Nothing to score. This also avoids dividing by zero in ROUGE-L.
        return metrics

    # BLEU: corpus-level, whitespace tokenization, smoothing method1.
    try:
        from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
        smooth = SmoothingFunction().method1
        refs_tokenized = [[ref.split()] for ref in references]
        preds_tokenized = [pred.split() for pred in predictions]
        metrics["BLEU-1"] = corpus_bleu(refs_tokenized, preds_tokenized,
                                        weights=(1, 0, 0, 0),
                                        smoothing_function=smooth)
        metrics["BLEU-4"] = corpus_bleu(refs_tokenized, preds_tokenized,
                                        weights=(0.25, 0.25, 0.25, 0.25),
                                        smoothing_function=smooth)
    except ImportError:
        print("Warning: nltk not installed, skipping BLEU. Install with: pip install nltk")

    # ROUGE-L: per-sample F-measure, averaged over samples.
    try:
        from rouge_score import rouge_scorer
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        rouge_scores = [scorer.score(ref, pred)["rougeL"].fmeasure
                        for pred, ref in zip(predictions, references)]
        metrics["ROUGE-L"] = sum(rouge_scores) / len(rouge_scores)
    except ImportError:
        print("Warning: rouge_score not installed, skipping ROUGE-L. Install with: pip install rouge-score")
    return metrics
# ============================================================
# Model loading
# ============================================================
def load_vora_model(checkpoint_path, device_map="auto", dtype=torch.float16,
                    max_memory_per_gpu="22GiB"):
    """Build a VoRA model from its config and load checkpoint weights into it.

    Weights are read with ``tools.merge_lora.partial_load_from_checkpoints``
    and loaded non-strictly. Missing and unexpected key counts are printed,
    so a partial load is visible. The model is then cast to ``dtype``.

    Placement: only ``device_map == "auto"`` is recognised. With "auto" and
    more than one visible GPU, the model is split across GPUs with accelerate.
    Otherwise it is moved to a single device (CUDA if available, else CPU).

    Args:
        checkpoint_path: Checkpoint directory (config + weights).
        device_map: ``"auto"`` to allow multi-GPU dispatch.
        dtype: Target parameter dtype.
        max_memory_per_gpu: Per-GPU memory budget handed to
            ``infer_auto_device_map`` (default matches the previously
            hard-coded "22GiB").

    Returns:
        The model in eval mode.
    """
    print(f"Loading VoRA model from {checkpoint_path} ...")
    config = VoRAConfig.from_pretrained(checkpoint_path)
    # Disable aux_vision for inference (not needed)
    config.aux_vision = ""
    model = VoRAForCausalLM(config)
    model.debug_max_steps = 0  # Disable debug prints

    # Load checkpoint weights (project-local helper, imported lazily).
    from tools.merge_lora import partial_load_from_checkpoints
    state_dict = partial_load_from_checkpoints(checkpoint_path)
    msg = model.load_state_dict(state_dict, strict=False)
    # load_state_dict copies the tensors into the model; drop the source dict
    # to lower peak memory before the dtype cast.
    del state_dict
    print(f"Load state dict: missing={len(msg.missing_keys)}, unexpected={len(msg.unexpected_keys)}")
    if msg.missing_keys:
        print(f" Missing keys (first 5): {msg.missing_keys[:5]}")
    if msg.unexpected_keys:
        print(f" Unexpected keys (first 5): {msg.unexpected_keys[:5]}")
    model = model.to(dtype=dtype)

    n_gpus = torch.cuda.device_count()
    if device_map == "auto" and n_gpus > 1:
        from accelerate import dispatch_model, infer_auto_device_map
        device_map_computed = infer_auto_device_map(
            model, max_memory={i: max_memory_per_gpu for i in range(n_gpus)})
        model = dispatch_model(model, device_map=device_map_computed)
        print(f"Model dispatched across {n_gpus} GPUs")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"Model on {device}")
    model.eval()
    return model
def load_merged_vora_model(merged_path, device_map="auto", dtype=torch.float16):
    """Load a VoRA model whose LoRA weights were already merged (LoRA-free).

    This is a thin wrapper around ``VoRAForCausalLM.from_pretrained``. Debug
    printing is disabled and the model is put in eval mode before it is
    returned.
    """
    print(f"Loading merged VoRA model from {merged_path} ...")
    load_kwargs = {
        "torch_dtype": dtype,
        "device_map": device_map,
        "trust_remote_code": True,
    }
    merged = VoRAForCausalLM.from_pretrained(merged_path, **load_kwargs)
    merged.debug_max_steps = 0
    merged.eval()
    return merged
# ============================================================
# Main
# ============================================================
def main():
    """Command-line entry point.

    Parse the arguments, load the model, tokenizer, image processor and eval
    data, run the selected evaluation(s), and optionally write the results
    to a JSON file.

    Returns:
        Dict with "checkpoint", "num_samples" and any computed metrics.
    """
    parser = argparse.ArgumentParser(description="VoRA Evaluation")
    parser.add_argument("--mode", type=str, default="all",
                        choices=["perplexity", "caption", "all"])
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to VoRA checkpoint or merged model directory")
    parser.add_argument("--merged", action="store_true",
                        help="If set, load as merged model (no LoRA)")
    parser.add_argument("--eval-data", type=str, required=True,
                        help="Path to eval_qwenvl.jsonl")
    parser.add_argument("--image-processor", type=str, required=True,
                        help="Path to AIMv2 model for image preprocessing")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Max number of eval samples (default: all)")
    parser.add_argument("--max-new-tokens", type=int, default=256,
                        help="Max new tokens for caption generation")
    parser.add_argument("--dtype", type=str, default="float16",
                        choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--output", type=str, default=None,
                        help="Path to save results JSON")
    args = parser.parse_args()
    dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,  # e.g. for CPU runs where fp16 is poorly supported
    }[args.dtype]

    # Load model
    if args.merged:
        model = load_merged_vora_model(args.checkpoint, dtype=dtype)
    else:
        model = load_vora_model(args.checkpoint, dtype=dtype)
    device = next(model.parameters()).device

    # Load tokenizer and image processor
    tokenizer = model.tokenizer
    image_processor = AutoImageProcessor.from_pretrained(args.image_processor)

    # Load eval data
    eval_data = load_eval_data(args.eval_data, max_samples=args.max_samples)
    results = {"checkpoint": args.checkpoint, "num_samples": len(eval_data)}

    # Run evaluations
    if args.mode in ("perplexity", "all"):
        ppl = evaluate_perplexity(model, tokenizer, image_processor, eval_data, device)
        results["perplexity"] = ppl
    if args.mode in ("caption", "all"):
        caption_metrics = evaluate_caption(
            model, tokenizer, image_processor, eval_data, device,
            max_new_tokens=args.max_new_tokens)
        results.update(caption_metrics)

    # Save results
    if args.output:
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        # json.dump emits inf/nan as the non-standard tokens Infinity/NaN,
        # which strict JSON parsers reject. Write them as null instead. The
        # returned `results` dict is left untouched.
        serializable = {
            k: (None if isinstance(v, float) and not math.isfinite(v) else v)
            for k, v in results.items()
        }
        # Explicit UTF-8: ensure_ascii=False may emit non-ASCII characters.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {args.output}")
    return results


if __name__ == "__main__":
    main()