"""
VoRA Evaluation Script
- Perplexity (cross-entropy loss) on held-out caption data
- Caption generation with BLEU / ROUGE-L metrics
Usage:
# Perplexity evaluation
python eval/eval_vora.py --mode perplexity \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
# Caption generation evaluation
python eval/eval_vora.py --mode caption \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
# Both
python eval/eval_vora.py --mode all \
--checkpoint output/pretrain_I30M_T6M/checkpoint-250 \
--eval-data data_dir/VoRA-Recap-29M/eval_qwenvl.jsonl \
--image-processor qwen_models/models--apple--aimv2-huge-patch14-448/snapshots/f723839533d3bbdc969f541c864789f531ec0e5c
"""
import argparse
import json
import math
import os
import sys
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoTokenizer
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.modeling_vora import VoRAForCausalLM, VoRAConfig
# ============================================================
# Image preprocessing (same as training pipeline)
# ============================================================
def expand2square(pil_img):
    """Pad *pil_img* with black so that it becomes square (same as training).

    The image is centred along its shorter axis on a new canvas whose side is
    the longer dimension. A square input is returned unchanged (the very same
    object, not a copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), (0, 0, 0))
    # Landscape images are shifted down, portrait images are shifted right.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def load_and_process_image(image_path, image_processor):
    """Load an image from disk, pad it to a square and run the HF processor.

    Args:
        image_path: Path to an image file readable by PIL.
        image_processor: Hugging Face image processor; called with
            ``return_tensors="pt"`` and expected to return a mapping with a
            ``"pixel_values"`` entry.

    Returns:
        The processor's ``pixel_values`` tensor for this single image. For the
        AIMv2-448 processor this is presumably (1, 3, 448, 448); confirm
        against the processor config.
    """
    # A context manager releases the underlying file handle deterministically
    # instead of relying on PIL closing it after the lazy load. convert()
    # returns a new image, so it remains valid after the file is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    img = expand2square(img)
    return image_processor(img, return_tensors="pt")["pixel_values"]
# ============================================================
# Text processing (same prompt template as training)
# ============================================================
# Placeholder id spliced into input_ids where the image goes. It is passed to
# the model as ``vision_placeholder_index``, and its negative value cannot
# collide with a real tokenizer id.
IMAGE_TOKEN_INDEX: int = -200
# Label value used for positions that should not be scored. -100 is also the
# default ``ignore_index`` of torch.nn.functional.cross_entropy; presumably
# the model's loss relies on that. TODO confirm.
IGNORE_INDEX: int = -100
def build_prompt_ids(tokenizer, has_image=True):
    """Return the token ids of the system + user turn used as the caption prompt.

    With ``has_image`` the user turn contains nothing but a single
    ``IMAGE_TOKEN_INDEX`` placeholder, spliced between the separately encoded
    text before and after it. Without an image, an explicit text instruction
    is encoded instead.
    """
    head = ("<|im_start|>system\n" "You are a helpful assistant." "<|im_end|>"
            "\n<|im_start|>user\n")
    tail = "<|im_end|>\n<|im_start|>assistant\n"
    if not has_image:
        return tokenizer.encode(head + "Describe this image." + tail)
    # The placeholder is inserted as a raw id so the tokenizer never sees it.
    return [*tokenizer.encode(head), IMAGE_TOKEN_INDEX, *tokenizer.encode(tail)]
def build_perplexity_batch(tokenizer, image_path, caption, image_processor, device):
    """Assemble a single-sample, teacher-forced batch for perplexity scoring.

    The token sequence is ``prompt + caption + <|im_end|>``. Prompt positions
    are labelled ``IGNORE_INDEX``; the caption tokens and the closing
    ``<|im_end|>`` keep their own ids as labels.

    Returns:
        ``(batch, n_target_tokens)``: the model-input dict and the number of
        labelled tokens (caption tokens + 1 for ``<|im_end|>``).
    """
    prompt = build_prompt_ids(tokenizer, has_image=True)
    target = tokenizer.encode(caption) + [tokenizer.convert_tokens_to_ids("<|im_end|>")]
    frames = load_and_process_image(image_path, image_processor)

    sequence = prompt + target
    masked_labels = [IGNORE_INDEX] * len(prompt) + target
    batch = {
        "input_ids": torch.tensor([sequence], dtype=torch.long, device=device),
        "attention_mask": torch.ones(1, len(sequence), dtype=torch.long, device=device),
        "labels": torch.tensor([masked_labels], dtype=torch.long, device=device),
        "frames": frames.to(device),
        "n_frames": [1],
        "vision_placeholder_index": IMAGE_TOKEN_INDEX,
    }
    return batch, len(target)
def build_generation_batch(tokenizer, image_path, image_processor, device):
    """Assemble a single-sample prompt batch (no labels) for caption generation."""
    prompt = build_prompt_ids(tokenizer, has_image=True)
    frames = load_and_process_image(image_path, image_processor)
    ids = torch.tensor([prompt], dtype=torch.long, device=device)
    return {
        "input_ids": ids,
        # Every prompt position is attended to; same shape/dtype/device as ids.
        "attention_mask": torch.ones_like(ids),
        "frames": frames.to(device),
        "n_frames": [1],
        "vision_placeholder_index": IMAGE_TOKEN_INDEX,
    }
# ============================================================
# Load evaluation data
# ============================================================
def load_eval_data(eval_path, max_samples=None):
    """Load evaluation records from a JSON-Lines file.

    Each non-blank line must be a JSON object. The evaluators in this script
    read its ``"image"`` (path) and ``"text"`` (reference caption) keys.

    Args:
        eval_path: Path to the ``.jsonl`` file, read as UTF-8.
        max_samples: Stop after this many records. ``None`` (or any other
            falsy value such as 0) means no limit.

    Returns:
        List of the decoded records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    # Explicit UTF-8: captions may be non-ASCII and the locale default may differ.
    with open(eval_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads("").
                continue
            data.append(json.loads(line))
            if max_samples and len(data) >= max_samples:
                break
    print(f"Loaded {len(data)} evaluation samples")
    return data
# ============================================================
# Evaluation: Perplexity
# ============================================================
@torch.no_grad()
def evaluate_perplexity(model, tokenizer, image_processor, eval_data, device):
    """Compute corpus-level perplexity of the reference captions.

    Each sample's loss is weighted by its number of labelled tokens (caption
    + closing ``<|im_end|>``) before averaging, so the result is
    ``exp(sum(loss_i * n_i) / sum(n_i))``. This assumes
    ``model(**batch).loss`` is the mean cross-entropy over the labelled
    tokens; TODO confirm against the model's forward.

    Samples are counted as errors and skipped when:
    - the image is missing,
    - building the batch or running the model raises, or
    - the loss is non-finite.
    Diagnostics are printed only while the running error count is at most five.

    Returns:
        The perplexity as a float. ``float("inf")`` if no sample could be
        scored or the average loss is too large for ``math.exp``.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    errors = 0
    for i, item in enumerate(tqdm(eval_data, desc="Perplexity")):
        image_path = item["image"]
        caption = item["text"]
        if not os.path.exists(image_path):
            errors += 1
            continue
        try:
            batch, n_caption_tokens = build_perplexity_batch(
                tokenizer, image_path, caption, image_processor, device)
            loss = model(**batch).loss.item()
        except Exception as e:  # best-effort: one bad sample must not abort the run
            errors += 1
            if errors <= 5:
                print(f" Error on sample {i}: {e}")
            continue
        if not math.isfinite(loss):
            # A single NaN/inf (e.g. fp16 overflow) would poison the whole
            # token-weighted average, so treat it as a failed sample.
            errors += 1
            if errors <= 5:
                print(f" Non-finite loss on sample {i}: {loss}")
            continue
        total_loss += loss * n_caption_tokens
        total_tokens += n_caption_tokens
    if total_tokens == 0:
        print("No valid samples for perplexity!")
        return float("inf")
    avg_loss = total_loss / total_tokens
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        # math.exp overflows for arguments above ~709.78.
        perplexity = float("inf")
    print("\n=== Perplexity Results ===")
    print(f"Samples evaluated: {len(eval_data) - errors}/{len(eval_data)}")
    print(f"Errors: {errors}")
    print(f"Average cross-entropy loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    return perplexity
# ============================================================
# Evaluation: Caption Generation
# ============================================================
@torch.no_grad()
def evaluate_caption(model, tokenizer, image_processor, eval_data, device,
                     max_new_tokens=256):
    """Greedy-decode one caption per sample and score it against the reference.

    ``model.generate`` receives the batch dict as its first positional
    argument and decoding stops at ``<|im_end|>``. Samples are counted as
    errors and skipped when the image is missing or generation raises.
    Exception messages are printed only while the running error count is at
    most five. The first five prediction/reference pairs are echoed at the end.

    Returns:
        The metric dict from ``compute_caption_metrics``, or ``{}`` when no
        sample succeeded.
    """
    model.eval()
    stop_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    pairs = []  # (prediction, reference) for every successful sample
    n_failed = 0
    for idx, sample in enumerate(tqdm(eval_data, desc="Caption Generation")):
        img_path, ref = sample["image"], sample["text"]
        if not os.path.exists(img_path):
            n_failed += 1
            continue
        try:
            inputs = build_generation_batch(tokenizer, img_path, image_processor, device)
            generated = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_id,
            )
            pairs.append((tokenizer.decode(generated[0], skip_special_tokens=True), ref))
        except Exception as exc:
            n_failed += 1
            if n_failed <= 5:
                print(f" Error on sample {idx}: {exc}")

    if not pairs:
        print("No valid samples for caption evaluation!")
        return {}

    predictions = [pred for pred, _ in pairs]
    references = [gold for _, gold in pairs]
    metrics = compute_caption_metrics(predictions, references)

    print(f"\n=== Caption Generation Results ===")
    print(f"Samples evaluated: {len(predictions)}/{len(eval_data)}")
    print(f"Errors: {n_failed}")
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")

    print(f"\n--- Sample Outputs (first 5) ---")
    for k, (pred, gold) in enumerate(pairs[:5]):
        print(f"[{k}] Generated: {pred[:200]}")
        print(f"[{k}] Reference: {gold[:200]}")
        print()
    return metrics
def compute_caption_metrics(predictions, references):
    """Compute corpus BLEU-1 / BLEU-4 and mean ROUGE-L F1 for generated captions.

    BLEU uses plain whitespace tokenization (case- and punctuation-sensitive)
    with NLTK's ``method1`` smoothing and a single reference per prediction.
    ROUGE-L is the mean per-sample F-measure from ``rouge_score`` with
    stemming. A metric is skipped, with a warning, when its library is not
    installed.

    Args:
        predictions: Generated captions.
        references: Reference captions, aligned index-by-index with
            ``predictions``.

    Returns:
        Dict mapping metric name ("BLEU-1", "BLEU-4", "ROUGE-L") to its
        score. Empty when the inputs are empty or no metric library is
        available.

    Raises:
        ValueError: if ``predictions`` and ``references`` differ in length.
    """
    if len(predictions) != len(references):
        # zip() would silently truncate and misalign the pairs.
        raise ValueError(
            f"predictions ({len(predictions)}) and references "
            f"({len(references)}) must have the same length")
    metrics = {}
    if not predictions:
        # Nothing to score; the ROUGE-L mean below would divide by zero.
        return metrics
    # BLEU
    try:
        from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
        smooth = SmoothingFunction().method1
        refs_tokenized = [[ref.split()] for ref in references]
        preds_tokenized = [pred.split() for pred in predictions]
        metrics["BLEU-1"] = corpus_bleu(refs_tokenized, preds_tokenized,
                                        weights=(1, 0, 0, 0),
                                        smoothing_function=smooth)
        metrics["BLEU-4"] = corpus_bleu(refs_tokenized, preds_tokenized,
                                        weights=(0.25, 0.25, 0.25, 0.25),
                                        smoothing_function=smooth)
    except ImportError:
        print("Warning: nltk not installed, skipping BLEU. Install with: pip install nltk")
    # ROUGE-L
    try:
        from rouge_score import rouge_scorer
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        rouge_scores = [scorer.score(ref, pred)["rougeL"].fmeasure
                        for pred, ref in zip(predictions, references)]
        metrics["ROUGE-L"] = sum(rouge_scores) / len(rouge_scores)
    except ImportError:
        print("Warning: rouge_score not installed, skipping ROUGE-L. Install with: pip install rouge-score")
    return metrics
# ============================================================
# Model loading
# ============================================================
def load_vora_model(checkpoint_path, device_map="auto", dtype=torch.float16):
    """Build a VoRA model from a training checkpoint and place it on device(s).

    Args:
        checkpoint_path: Directory with the VoRA config and the checkpoint
            files read by ``tools.merge_lora.partial_load_from_checkpoints``.
        device_map: ``"auto"`` shards the model across all visible GPUs when
            more than one is available; otherwise the model goes to ``cuda``
            (or ``cpu`` without CUDA). Any other string (e.g. ``"cpu"``,
            ``"cuda:1"``) is used as the target device directly. Previously
            such strings were silently ignored.
        dtype: Dtype the parameters are cast to after loading.

    Returns:
        The model in eval mode.
    """
    print(f"Loading VoRA model from {checkpoint_path} ...")
    config = VoRAConfig.from_pretrained(checkpoint_path)
    # Disable aux_vision for inference (not needed)
    config.aux_vision = ""
    model = VoRAForCausalLM(config)
    model.debug_max_steps = 0  # Disable debug prints
    # Imported lazily: only needed when loading an un-merged checkpoint.
    from tools.merge_lora import partial_load_from_checkpoints
    state_dict = partial_load_from_checkpoints(checkpoint_path)
    # strict=False: mismatches are tolerated and only reported below.
    msg = model.load_state_dict(state_dict, strict=False)
    print(f"Load state dict: missing={len(msg.missing_keys)}, unexpected={len(msg.unexpected_keys)}")
    if msg.missing_keys:
        print(f"  Missing keys (first 5): {msg.missing_keys[:5]}")
    model = model.to(dtype=dtype)
    n_gpus = torch.cuda.device_count()
    if device_map == "auto" and n_gpus > 1:
        from accelerate import dispatch_model, infer_auto_device_map
        # Hard-coded per-GPU memory budget (presumably sized for 24 GB cards;
        # adjust for other hardware).
        device_map_computed = infer_auto_device_map(model, max_memory={
            i: "22GiB" for i in range(n_gpus)
        })
        model = dispatch_model(model, device_map=device_map_computed)
        print(f"Model dispatched across {n_gpus} GPUs")
    else:
        if isinstance(device_map, str) and device_map != "auto":
            # Honour an explicit device request instead of ignoring it.
            device = torch.device(device_map)
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        print(f"Model on {device}")
    model.eval()
    return model
def load_merged_vora_model(merged_path, device_map="auto", dtype=torch.float16):
    """Load a VoRA model whose LoRA weights have already been merged.

    ``dtype`` and ``device_map`` are forwarded to ``from_pretrained``. Debug
    printing is disabled and the model is returned in eval mode.
    """
    print(f"Loading merged VoRA model from {merged_path} ...")
    load_kwargs = {
        "torch_dtype": dtype,
        "device_map": device_map,
        "trust_remote_code": True,
    }
    model = VoRAForCausalLM.from_pretrained(merged_path, **load_kwargs)
    model.debug_max_steps = 0
    model.eval()
    return model
# ============================================================
# Main
# ============================================================
def _parse_args():
    """Define and parse this script's command-line interface."""
    ap = argparse.ArgumentParser(description="VoRA Evaluation")
    ap.add_argument("--mode", type=str, default="all",
                    choices=["perplexity", "caption", "all"])
    ap.add_argument("--checkpoint", type=str, required=True,
                    help="Path to VoRA checkpoint or merged model directory")
    ap.add_argument("--merged", action="store_true",
                    help="If set, load as merged model (no LoRA)")
    ap.add_argument("--eval-data", type=str, required=True,
                    help="Path to eval_qwenvl.jsonl")
    ap.add_argument("--image-processor", type=str, required=True,
                    help="Path to AIMv2 model for image preprocessing")
    ap.add_argument("--max-samples", type=int, default=None,
                    help="Max number of eval samples (default: all)")
    ap.add_argument("--max-new-tokens", type=int, default=256,
                    help="Max new tokens for caption generation")
    ap.add_argument("--dtype", type=str, default="float16",
                    choices=["float16", "bfloat16"])
    ap.add_argument("--output", type=str, default=None,
                    help="Path to save results JSON")
    return ap.parse_args()


def main():
    """Run the requested evaluations and optionally write the results as JSON.

    Returns:
        Dict with the checkpoint path, the number of loaded samples, and
        whichever metrics were computed.
    """
    args = _parse_args()
    # --dtype is restricted by argparse choices to exactly these two keys.
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.dtype]

    loader = load_merged_vora_model if args.merged else load_vora_model
    model = loader(args.checkpoint, dtype=dtype)
    # Inputs go to the device holding the first parameter.
    device = next(model.parameters()).device

    tokenizer = model.tokenizer
    image_processor = AutoImageProcessor.from_pretrained(args.image_processor)
    eval_data = load_eval_data(args.eval_data, max_samples=args.max_samples)

    results = {"checkpoint": args.checkpoint, "num_samples": len(eval_data)}
    if args.mode in ("perplexity", "all"):
        results["perplexity"] = evaluate_perplexity(
            model, tokenizer, image_processor, eval_data, device)
    if args.mode in ("caption", "all"):
        results.update(evaluate_caption(
            model, tokenizer, image_processor, eval_data, device,
            max_new_tokens=args.max_new_tokens))

    if args.output:
        out_dir = os.path.dirname(args.output) or "."
        os.makedirs(out_dir, exist_ok=True)
        with open(args.output, "w") as fh:
            json.dump(results, fh, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {args.output}")
    return results


if __name__ == "__main__":
    main()