""" VoRA Evaluation Script - Perplexity (cross-entropy loss) on held-out caption data - Caption generation with BLEU / ROUGE-L metrics Usage: # Perplexity evaluation python eval/eval_vora.py --mode perplexity \ --checkpoint output/pretrain_I30M_T6M/checkpoint-250 \ --eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \ --image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c # Caption generation evaluation python eval/eval_vora.py --mode caption \ --checkpoint output/pretrain_I30M_T6M/checkpoint-250 \ --eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \ --image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c # Both python eval/eval_vora.py --mode all \ --checkpoint output/pretrain_I30M_T6M/checkpoint-250 \ --eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \ --image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c """ import argparse import json import math import os import sys import torch import torch.nn.functional as F from PIL import Image from tqdm import tqdm from transformers import AutoImageProcessor, AutoTokenizer # Add project root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.modeling_vora import VoRAForCausalLM, VoRAConfig # ============================================================ # Image preprocessing (same as training pipeline) # ============================================================ def expand2square(pil_img): """Expand image to square with black padding (same as training).""" background_color = (0, 0, 0) width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result def load_and_process_image(image_path, image_processor): """Load image, expand to square, apply HF image transforms.""" img = Image.open(image_path).convert("RGB") img = expand2square(img) pixel_values = image_processor(img, return_tensors="pt")["pixel_values"] # (1, 3, 448, 448) return pixel_values # ============================================================ # Text processing (same prompt template as training) # ============================================================ IMAGE_TOKEN_INDEX = -200 IGNORE_INDEX = -100 def build_prompt_ids(tokenizer, has_image=True): """Build the prompt token IDs (system + user turn) for captioning.""" system_start = "<|im_start|>system\n" system_message = "You are a helpful assistant." system_end = "<|im_end|>" user_start = "\n<|im_start|>user\n" user_end = "<|im_end|>\n<|im_start|>assistant\n" if has_image: # system + user with placeholder prompt = system_start + system_message + system_end + user_start prompt_after_image = user_end prompt_ids = tokenizer.encode(prompt) after_image_ids = tokenizer.encode(prompt_after_image) # Insert image token index between prompt and after_image input_ids = prompt_ids + [IMAGE_TOKEN_INDEX] + after_image_ids else: prompt = (system_start + system_message + system_end + user_start + "Describe this image." + user_end) input_ids = tokenizer.encode(prompt) return input_ids def build_perplexity_batch(tokenizer, image_path, caption, image_processor, device): """Build a batch for perplexity evaluation (with labels).""" prompt_ids = build_prompt_ids(tokenizer, has_image=True) caption_ids = tokenizer.encode(caption) eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>") full_ids = prompt_ids + caption_ids + [eos_id] # Labels: -100 for prompt tokens, actual IDs for caption tokens labels = [IGNORE_INDEX] * len(prompt_ids) + caption_ids + [eos_id] # Load image pixel_values = load_and_process_image(image_path, image_processor) batch = { "input_ids": torch.tensor([full_ids], dtype=torch.long).to(device), "attention_mask": torch.ones(1, len(full_ids), dtype=torch.long).to(device), "labels": torch.tensor([labels], dtype=torch.long).to(device), "frames": pixel_values.to(device), # (1, 3, 448, 448) "n_frames": [1], "vision_placeholder_index": IMAGE_TOKEN_INDEX, } return batch, len(caption_ids) + 1 # +1 for eos def build_generation_batch(tokenizer, image_path, image_processor, device): """Build a batch for caption generation (no labels).""" prompt_ids = build_prompt_ids(tokenizer, has_image=True) pixel_values = load_and_process_image(image_path, image_processor) batch = { "input_ids": torch.tensor([prompt_ids], dtype=torch.long).to(device), "attention_mask": torch.ones(1, len(prompt_ids), dtype=torch.long).to(device), "frames": pixel_values.to(device), "n_frames": [1], "vision_placeholder_index": IMAGE_TOKEN_INDEX, } return batch # ============================================================ # Load evaluation data # ============================================================ def load_eval_data(eval_path, max_samples=None): """Load eval data from eval_qwenvl.jsonl format: {"image": path, "text": caption}""" data = [] with open(eval_path, "r") as f: for line in f: item = json.loads(line.strip()) data.append(item) if max_samples and len(data) >= max_samples: break print(f"Loaded {len(data)} evaluation samples") return data # ============================================================ # Evaluation: Perplexity # ============================================================ @torch.no_grad() def evaluate_perplexity(model, tokenizer, image_processor, eval_data, device): """Compute perplexity on held-out caption data.""" model.eval() total_loss = 0.0 total_tokens = 0 errors = 0 for i, item in enumerate(tqdm(eval_data, desc="Perplexity")): image_path = item["image"] caption = item["text"] if not os.path.exists(image_path): errors += 1 continue try: batch, n_caption_tokens = build_perplexity_batch( tokenizer, image_path, caption, image_processor, device) outputs = model(**batch) loss = outputs.loss total_loss += loss.item() * n_caption_tokens total_tokens += n_caption_tokens except Exception as e: errors += 1 if errors <= 5: print(f" Error on sample {i}: {e}") continue if total_tokens == 0: print("No valid samples for perplexity!") return float("inf") avg_loss = total_loss / total_tokens perplexity = math.exp(avg_loss) print(f"\n=== Perplexity Results ===") print(f"Samples evaluated: {len(eval_data) - errors}/{len(eval_data)}") print(f"Errors: {errors}") print(f"Average cross-entropy loss: {avg_loss:.4f}") print(f"Perplexity: {perplexity:.2f}") return perplexity # ============================================================ # Evaluation: Caption Generation # ============================================================ @torch.no_grad() def evaluate_caption(model, tokenizer, image_processor, eval_data, device, max_new_tokens=256): """Generate captions and compute BLEU / ROUGE-L.""" model.eval() predictions = [] references = [] errors = 0 eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") for i, item in enumerate(tqdm(eval_data, desc="Caption Generation")): image_path = item["image"] caption = item["text"] if not os.path.exists(image_path): errors += 1 continue try: batch = build_generation_batch(tokenizer, image_path, image_processor, device) outputs = model.generate( batch, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.eos_token_id, eos_token_id=eos_token_id, ) generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) predictions.append(generated_text) references.append(caption) except Exception as e: errors += 1 if errors <= 5: print(f" Error on sample {i}: {e}") continue if len(predictions) == 0: print("No valid samples for caption evaluation!") return {} # Compute metrics metrics = compute_caption_metrics(predictions, references) print(f"\n=== Caption Generation Results ===") print(f"Samples evaluated: {len(predictions)}/{len(eval_data)}") print(f"Errors: {errors}") for k, v in metrics.items(): print(f"{k}: {v:.4f}") # Print a few examples print(f"\n--- Sample Outputs (first 5) ---") for i in range(min(5, len(predictions))): print(f"[{i}] Generated: {predictions[i][:200]}") print(f"[{i}] Reference: {references[i][:200]}") print() return metrics def compute_caption_metrics(predictions, references): """Compute BLEU-1, BLEU-4, ROUGE-L metrics.""" metrics = {} # BLEU try: from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction smooth = SmoothingFunction().method1 refs_tokenized = [[ref.split()] for ref in references] preds_tokenized = [pred.split() for pred in predictions] metrics["BLEU-1"] = corpus_bleu(refs_tokenized, preds_tokenized, weights=(1, 0, 0, 0), smoothing_function=smooth) metrics["BLEU-4"] = corpus_bleu(refs_tokenized, preds_tokenized, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth) except ImportError: print("Warning: nltk not installed, skipping BLEU. Install with: pip install nltk") # ROUGE-L try: from rouge_score import rouge_scorer scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) rouge_scores = [scorer.score(ref, pred)["rougeL"].fmeasure for pred, ref in zip(predictions, references)] metrics["ROUGE-L"] = sum(rouge_scores) / len(rouge_scores) except ImportError: print("Warning: rouge_score not installed, skipping ROUGE-L. Install with: pip install rouge-score") return metrics # ============================================================ # Model loading # ============================================================ def load_vora_model(checkpoint_path, device_map="auto", dtype=torch.float16): """Load VoRA model from checkpoint.""" print(f"Loading VoRA model from {checkpoint_path} ...") config = VoRAConfig.from_pretrained(checkpoint_path) # Disable aux_vision for inference (not needed) config.aux_vision = "" model = VoRAForCausalLM(config) model.debug_max_steps = 0 # Disable debug prints # Load checkpoint weights from tools.merge_lora import partial_load_from_checkpoints state_dict = partial_load_from_checkpoints(checkpoint_path) msg = model.load_state_dict(state_dict, strict=False) print(f"Load state dict: missing={len(msg.missing_keys)}, unexpected={len(msg.unexpected_keys)}") if msg.missing_keys: print(f" Missing keys (first 5): {msg.missing_keys[:5]}") model = model.to(dtype=dtype) if device_map == "auto" and torch.cuda.device_count() > 1: from accelerate import dispatch_model, infer_auto_device_map device_map_computed = infer_auto_device_map(model, max_memory={ i: "22GiB" for i in range(torch.cuda.device_count()) }) model = dispatch_model(model, device_map=device_map_computed) print(f"Model dispatched across {torch.cuda.device_count()} GPUs") else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) print(f"Model on {device}") model.eval() return model def load_merged_vora_model(merged_path, device_map="auto", dtype=torch.float16): """Load merged (LoRA-free) VoRA model.""" print(f"Loading merged VoRA model from {merged_path} ...") model = VoRAForCausalLM.from_pretrained( merged_path, torch_dtype=dtype, device_map=device_map, trust_remote_code=True, ) model.debug_max_steps = 0 model.eval() return model # ============================================================ # Main # ============================================================ def main(): parser = argparse.ArgumentParser(description="VoRA Evaluation") parser.add_argument("--mode", type=str, default="all", choices=["perplexity", "caption", "all"]) parser.add_argument("--checkpoint", type=str, required=True, help="Path to VoRA checkpoint or merged model directory") parser.add_argument("--merged", action="store_true", help="If set, load as merged model (no LoRA)") parser.add_argument("--eval-data", type=str, required=True, help="Path to eval_qwenvl.jsonl") parser.add_argument("--image-processor", type=str, required=True, help="Path to AIMv2 model for image preprocessing") parser.add_argument("--max-samples", type=int, default=None, help="Max number of eval samples (default: all)") parser.add_argument("--max-new-tokens", type=int, default=256, help="Max new tokens for caption generation") parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16"]) parser.add_argument("--output", type=str, default=None, help="Path to save results JSON") args = parser.parse_args() dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 # Load model if args.merged: model = load_merged_vora_model(args.checkpoint, dtype=dtype) else: model = load_vora_model(args.checkpoint, dtype=dtype) device = next(model.parameters()).device # Load tokenizer and image processor tokenizer = model.tokenizer image_processor = AutoImageProcessor.from_pretrained(args.image_processor) # Load eval data eval_data = load_eval_data(args.eval_data, max_samples=args.max_samples) results = {"checkpoint": args.checkpoint, "num_samples": len(eval_data)} # Run evaluations if args.mode in ("perplexity", "all"): ppl = evaluate_perplexity(model, tokenizer, image_processor, eval_data, device) results["perplexity"] = ppl if args.mode in ("caption", "all"): caption_metrics = evaluate_caption( model, tokenizer, image_processor, eval_data, device, max_new_tokens=args.max_new_tokens) results.update(caption_metrics) # Save results if args.output: os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) with open(args.output, "w") as f: json.dump(results, f, indent=2, ensure_ascii=False) print(f"\nResults saved to {args.output}") return results if __name__ == "__main__": main()