# CosFly-Track / scripts/eval_vln_transformers.py
"""
VLN Waypoint Prediction Evaluation — transformers fallback version
For models incompatible with vLLM (e.g., InternVL converted from ModelScope).
"""
import argparse
import json
import os
import re
import time
import logging
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
# The six pose-delta dimensions each predicted waypoint is scored on.
DIMS = ["dx", "dy", "dz", "dpitch", "dyaw", "droll"]
# NOTE(review): not referenced anywhere in this file; presumably the expected
# number of waypoints per sample -- confirm against the shared eval helpers.
NUM_WAYPOINTS = 5
def load_val_data(val_path: str) -> List[Dict]:
    """Load validation samples from a JSONL file.

    Args:
        val_path: Path to a JSONL file, one JSON object per line.

    Returns:
        List of parsed sample dicts, in file order.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    data = []
    with open(val_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at EOF) instead of
            # crashing on json.loads("").
            if line:
                data.append(json.loads(line))
    logger.info(f"Loaded {len(data)} validation samples")
    return data
def parse_waypoints(text: str) -> Optional[List[Dict]]:
    """Extract a list of waypoint-delta dicts from a model response.

    Scans the text after any chain-of-thought (</think>) section for the
    first {...} JSON object and normalizes its "waypoint_deltas" entries:
    each entry may be a dict (missing dims default to 0.0) or a sequence
    covering at least all DIMS. Returns None when nothing parseable is found.
    """
    try:
        # Keep only the text after the last </think>; rpartition leaves the
        # input untouched when the tag is absent.
        text = text.rpartition("</think>")[2]
        blob = re.search(r'\{.*\}', text, re.DOTALL)
        if blob is None:
            return None
        deltas = json.loads(blob.group()).get("waypoint_deltas", [])
        if len(deltas) == 0:
            return None
        parsed = []
        for entry in deltas:
            if isinstance(entry, dict):
                parsed.append({dim: float(entry.get(dim, 0.0)) for dim in DIMS})
            elif isinstance(entry, (list, tuple)) and len(entry) >= len(DIMS):
                parsed.append({dim: float(entry[i]) for i, dim in enumerate(DIMS)})
            else:
                # Unrecognized entry shape invalidates the whole response.
                return None
        return parsed
    except (json.JSONDecodeError, ValueError, TypeError, AttributeError, IndexError):
        return None
# reuse compute_metrics and print_results from eval_vln_vllm
from eval_vln_vllm import compute_metrics, print_results
def load_internvl_model(model_path, device="cuda"):
    """Load an InternVL checkpoint through transformers.

    Note: `device` is currently unused -- placement is delegated entirely
    to device_map="auto".

    Returns:
        (model, tokenizer, processor), with the model in eval mode.
    """
    from transformers import AutoModelForImageTextToText, AutoTokenizer, AutoProcessor

    logger.info(f"Loading InternVL model from {model_path}")
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    return model.eval(), processor.tokenizer, processor
def load_generic_model(model_path, device="cuda"):
    """Load an arbitrary causal-LM checkpoint through transformers.

    Note: `device` is currently unused -- placement is delegated entirely
    to device_map="auto".

    Returns:
        (model, tokenizer, processor); processor is None when the
        checkpoint ships no processor config.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor

    logger.info(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    try:
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    except Exception:
        # Text-only checkpoints may not provide any processor config.
        processor = None
    load_kwargs = dict(dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs).eval()
    return model, tokenizer, processor
def internvl_generate(model, tokenizer, processor, item, max_new_tokens=512):
    """Run greedy generation for one eval sample and return the decoded reply.

    `item` carries "messages" (chat turns) and optionally "images" (file
    paths). Only the turns before the first assistant message are fed to
    the model; the assistant turn is the ground truth.
    """
    pil_images = [Image.open(p).convert("RGB") for p in item.get("images", [])]
    # The training data already embeds one "<image>\n" placeholder per image
    # inside the user text, so we pass content as plain strings and do NOT
    # add structured {"type": "image"} entries -- that would double-count
    # the placeholders.
    prompt_msgs = []
    for turn in item["messages"]:
        if turn["role"] == "assistant":
            break  # everything from the first assistant turn on is ground truth
        prompt_msgs.append({"role": turn["role"], "content": turn["content"]})
    # Render the template into a string holding N "<image>" placeholders.
    prompt = processor.apply_chat_template(
        prompt_msgs, add_generation_prompt=True, tokenize=False
    )
    if pil_images:
        # The InternVL processor expects its own image_token (usually
        # "<IMG_CONTEXT>"), not the chat template's "<image>" string; map
        # them 1:1 so each can expand into image_seq_length * num_patches
        # IMG_CONTEXT tokens matching the pixel_values slices.
        prompt = prompt.replace("<image>", getattr(processor, "image_token", "<IMG_CONTEXT>"))
    inputs = processor(
        text=prompt,
        images=pil_images if pil_images else None,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
def main():
    """CLI entry point: evaluate waypoint prediction and write metrics JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--val_path", type=str,
                        default="/mnt/data-a808/R26112/datasets/0318_vln_waypoint_val.jsonl")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--model_type", type=str, default="internvl",
                        choices=["internvl", "generic"])
    args = parser.parse_args()

    ckpt_name = os.path.basename(args.model_path.rstrip("/"))
    if args.output_dir is None:
        # Default: write results next to the checkpoint directory.
        args.output_dir = os.path.dirname(args.model_path.rstrip("/"))

    samples = load_val_data(args.val_path)
    if args.max_samples and args.max_samples < len(samples):
        samples = samples[:args.max_samples]

    load_fn = load_internvl_model if args.model_type == "internvl" else load_generic_model
    model, tokenizer, processor = load_fn(args.model_path)

    total = len(samples)
    error_records = []  # one {dim: [abs errors]} dict per scored sample
    n_failed = 0        # generation errors + unparseable model outputs
    for idx, item in enumerate(samples):
        # Ground truth lives in the first assistant turn of the conversation.
        gt_content = [m for m in item["messages"] if m["role"] == "assistant"][0]["content"]
        gt_wps = parse_waypoints(gt_content)
        if gt_wps is None:
            continue

        try:
            response = internvl_generate(model, tokenizer, processor, item)
        except Exception as e:
            # Full tracebacks for the first few samples, short messages up to
            # sample 8, then silence -- keeps logs readable on systemic failure.
            if idx < 3:
                import traceback
                logger.warning(f"Sample {idx}: generation error: {e}")
                logger.warning(traceback.format_exc())
            elif idx < 8:
                logger.warning(f"Sample {idx}: generation error: {e}")
            n_failed += 1
            continue

        pred_wps = parse_waypoints(response)
        if pred_wps is None:
            n_failed += 1
            if n_failed <= 5 or n_failed % 100 == 0:
                logger.warning(f"Sample {idx}: parse failure. Output: {response[:200]}")
            continue

        # Absolute per-dimension errors over the waypoints both lists share.
        record = {dim: [] for dim in DIMS}
        for wi in range(min(len(gt_wps), len(pred_wps))):
            for dim in DIMS:
                record[dim].append(abs(pred_wps[wi][dim] - gt_wps[wi][dim]))
        error_records.append(record)

        if (idx + 1) % 50 == 0:
            if error_records:
                cur_mae = np.mean([np.mean([e for s in error_records for e in s[dim]]) for dim in DIMS])
                logger.info(f"Progress [{idx+1}/{total}] MAE: {cur_mae:.4f} | parse_fail={n_failed}")
            else:
                logger.info(f"Progress [{idx+1}/{total}] | parse_fail={n_failed}")

    results = compute_metrics(error_records, n_failed, total)
    results["inference_engine"] = "transformers"
    print_results(results, ckpt_name)

    os.makedirs(args.output_dir, exist_ok=True)
    out_file = os.path.join(args.output_dir, f"eval_results_{ckpt_name}.json")
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {out_file}")
# Standard script guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()