"""
VLN Waypoint Prediction Evaluation — transformers fallback version
For models incompatible with vLLM (e.g., InternVL converted from ModelScope).
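
Example usage (the script filename is an assumption; adjust to the saved name):
    python eval_vln_transformers.py \
        --model_path /path/to/checkpoint \
        --model_type internvl \
        --max_samples 100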
"""
import argparse
import json
import os
import re
import time
import logging
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
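
# Per-waypoint pose delta dimensions: translation (dx, dy, dz) and
# rotation (dpitch, dyaw, droll).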
DIMS = ["dx", "dy", "dz", "dpitch", "dyaw", "droll"]
NUM_WAYPOINTS = 5
def load_val_data(val_path: str) -> List[Dict]:
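    """Read validation samples from a JSONL file (one JSON object per line)."""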
    data = []
    with open(val_path) as f:
        for line in f:
            data.append(json.loads(line.strip()))
    logger.info(f"Loaded {len(data)} validation samples")
    return data

def parse_waypoints(text: str) -> Optional[List[Dict]]:
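    """Extract waypoint deltas from model output text.

    Parses the first JSON object in the text (anything before a closing
    </think> tag is discarded first), shaped like (values illustrative):
        {"waypoint_deltas": [{"dx": 0.5, "dy": 0.0, "dz": 0.1,
                              "dpitch": 0.0, "dyaw": 0.2, "droll": 0.0}, ...]}
    Each delta may also be a list/tuple of at least len(DIMS) numbers in
    DIMS order. Returns a list of dicts keyed by DIMS, or None on failure.
    """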
    try:
        if "</think>" in text:
            text = text.split("</think>")[-1]
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if not match:
            return None
        obj = json.loads(match.group())
        deltas = obj.get("waypoint_deltas", [])
        if len(deltas) == 0:
            return None
        result = []
        for d in deltas:
            wp = {}
            if isinstance(d, dict):
                for dim in DIMS:
                    wp[dim] = float(d.get(dim, 0.0))
            elif isinstance(d, (list, tuple)) and len(d) >= len(DIMS):
                for i, dim in enumerate(DIMS):
                    wp[dim] = float(d[i])
            else:
                return None
            result.append(wp)
        return result
    except (json.JSONDecodeError, ValueError, TypeError, AttributeError, IndexError):
        return None

# reuse compute_metrics and print_results from eval_vln_vllm
from eval_vln_vllm import compute_metrics, print_results
def load_internvl_model(model_path, device="cuda"):
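    """Load an InternVL checkpoint in bfloat16 with device_map="auto".

    The `device` argument is currently unused; placement is handled by
    device_map. The tokenizer is taken from the processor.
    """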
    from transformers import AutoModelForImageTextToText, AutoProcessor
    logger.info(f"Loading InternVL model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = processor.tokenizer
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    ).eval()
    return model, tokenizer, processor

def load_generic_model(model_path, device="cuda"):
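    """Fallback loader via AutoModelForCausalLM for other model types.

    The `device` argument is currently unused; placement is handled by
    device_map. The processor is optional and may be None.
    """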
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
    logger.info(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    try:
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    except Exception:
        processor = None
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    ).eval()
    return model, tokenizer, processor

def internvl_generate(model, tokenizer, processor, item, max_new_tokens=512):
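    """Generate a greedy completion for one sample via the HF processor pipeline.

    `item` is expected to carry chat-format "messages" plus a list of image
    paths under "images", matching the training JSONL layout.
    """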
    messages = item["messages"]
    image_paths = item.get("images", [])
    images = [Image.open(p).convert("RGB") for p in image_paths]

    # NOTE: our training data already includes N "<image>\n" tokens at the
    # beginning of every user message (one per image). So we keep the user
    # content as a plain string and DO NOT add additional {"type": "image"}
    # entries -- that would double-count placeholders.
    chat_messages = []
    for msg in messages:
        if msg["role"] == "assistant":
            break
        chat_messages.append({"role": msg["role"], "content": msg["content"]})

    # Render template -> string with N "<image>" placeholders.
    text = processor.apply_chat_template(
        chat_messages, add_generation_prompt=True, tokenize=False
    )

    # InternVL processor expects N "<IMG_CONTEXT>" placeholders (its
    # image_token), NOT the chat-template's "<image>" string. Replace
    # them 1:1 so the processor can correctly expand each one into
    # `image_seq_length * num_patches` IMG_CONTEXT tokens that match
    # the corresponding pixel_values slice.
    if images:
        target_tok = getattr(processor, "image_token", "<IMG_CONTEXT>")
        text = text.replace("<image>", target_tok)

    inputs = processor(
        text=text,
        images=images if images else None,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens (everything past the prompt).
    input_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(output_ids[0][input_len:], skip_special_tokens=True)
    return response

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--val_path", type=str,
                        default="/mnt/data-a808/R26112/datasets/0318_vln_waypoint_val.jsonl")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--model_type", type=str, default="internvl",
                        choices=["internvl", "generic"])
    args = parser.parse_args()

    model_name = os.path.basename(args.model_path.rstrip("/"))
    if args.output_dir is None:
        args.output_dir = os.path.dirname(args.model_path.rstrip("/"))

    val_data = load_val_data(args.val_path)
    if args.max_samples and args.max_samples < len(val_data):
        val_data = val_data[:args.max_samples]

    if args.model_type == "internvl":
        model, tokenizer, processor = load_internvl_model(args.model_path)
    else:
        model, tokenizer, processor = load_generic_model(args.model_path)

    total = len(val_data)
    all_errors = []
    parse_failures = 0
    for idx, item in enumerate(val_data):
        gt_text = [m for m in item["messages"] if m["role"] == "assistant"][0]["content"]
        gt_wp = parse_waypoints(gt_text)
        if gt_wp is None:
            continue
        try:
            response = internvl_generate(model, tokenizer, processor, item)
        except Exception as e:
            if idx < 3:
                import traceback
                logger.warning(f"Sample {idx}: generation error: {e}")
                logger.warning(traceback.format_exc())
            elif idx < 8:
                logger.warning(f"Sample {idx}: generation error: {e}")
            parse_failures += 1
            continue
        pred_wp = parse_waypoints(response)
        if pred_wp is None:
            parse_failures += 1
            if parse_failures <= 5 or parse_failures % 100 == 0:
                logger.warning(f"Sample {idx}: parse failure. Output: {response[:200]}")
            continue
        n_wp = min(len(gt_wp), len(pred_wp))
        sample_errors = {dim: [] for dim in DIMS}
        for wi in range(n_wp):
            for dim in DIMS:
                err = abs(pred_wp[wi][dim] - gt_wp[wi][dim])
                sample_errors[dim].append(err)
        all_errors.append(sample_errors)
        if (idx + 1) % 50 == 0:
            if all_errors:
                cur_mae = np.mean([np.mean([e for s in all_errors for e in s[dim]]) for dim in DIMS])
                logger.info(f"Progress [{idx+1}/{total}] MAE: {cur_mae:.4f} | parse_fail={parse_failures}")
            else:
                logger.info(f"Progress [{idx+1}/{total}] | parse_fail={parse_failures}")

    results = compute_metrics(all_errors, parse_failures, total)
    results["inference_engine"] = "transformers"
    print_results(results, model_name)

    os.makedirs(args.output_dir, exist_ok=True)
    out_file = os.path.join(args.output_dir, f"eval_results_{model_name}.json")
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {out_file}")

if __name__ == "__main__":
    main()