# egosocial-env/scripts/eval_reason2.py
# Uploaded by robertzty: "Upload folder using huggingface_hub" (commit a0a453b, verified).
"""Evaluate a Reason2-style checkpoint on the fixed EgoNormia heldout split."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence
# Put this script's own directory on sys.path (once) so the sibling module
# imported right below (train_grpo_reason2) resolves when run as a script.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
from train_grpo_reason2 import (
DEFAULT_VERIFIED_SPLIT_PATH,
SYSTEM_PROMPT,
_as_float,
_batch_decode,
_build_user_prompt,
_load_scene_id_set,
_parse_action,
_tokenize_prompt,
_tokenizer_from_processor,
)
# parents[2] is the directory two levels above the one containing this script;
# it is put on sys.path (once) so the egosocial_env package import below resolves.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from egosocial_env.server.egosocial_env_environment import (
ENV_MODE_BENCHMARK,
ENV_MODE_TRAIN,
EgosocialEnvironment,
)
def parse_args() -> argparse.Namespace:
    """Define the evaluation CLI and parse the process arguments.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``model_id``, ``split_path``, ``output_path``, ``env_mode``,
        ``world_model_provider``, ``max_samples``, ``max_completion_length``,
        ``temperature``, ``top_p``, ``trust_remote_code`` and ``bf16``.
    """
    # Declarative option table: (flag, add_argument keyword arguments).
    # The order here is the order the options show up in ``--help``.
    option_table = [
        (
            "--model-id",
            {
                "required": True,
                "help": "HF model id or local checkpoint path.",
            },
        ),
        (
            "--split-path",
            {
                "default": str(DEFAULT_VERIFIED_SPLIT_PATH),
                "help": "JSON file with {'split': [...]} scene ids to evaluate.",
            },
        ),
        (
            "--output-path",
            {
                "default": "outputs/eval_reason2_verified.json",
                "help": "Where to write the evaluation summary and per-scene results.",
            },
        ),
        (
            "--env-mode",
            {
                "choices": [ENV_MODE_BENCHMARK, ENV_MODE_TRAIN],
                "default": ENV_MODE_BENCHMARK,
            },
        ),
        (
            "--world-model-provider",
            {
                "default": None,
                "help": "Optional world-model provider for train mode, such as 'cosmos'.",
            },
        ),
        (
            "--max-samples",
            {
                "type": int,
                "default": -1,
                "help": "Evaluate only the first N scene ids from the split.",
            },
        ),
        ("--max-completion-length", {"type": int, "default": 320}),
        (
            "--temperature",
            {
                "type": float,
                "default": 0.0,
                "help": "Use 0.0 for deterministic evaluation.",
            },
        ),
        ("--top-p", {"type": float, "default": 1.0}),
        ("--trust-remote-code", {"action": "store_true"}),
        ("--bf16", {"action": "store_true"}),
    ]

    cli = argparse.ArgumentParser(
        description="Evaluate a Reason2-style checkpoint on the fixed EgoNormia heldout split.",
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _load_eval_stack() -> Dict[str, Any]:
try:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
except ImportError as exc: # pragma: no cover
raise ImportError(
"Evaluation dependencies are missing. Install them with "
"`UV_CACHE_DIR=/tmp/uv-cache uv sync --extra train`."
) from exc
extra_model_classes: List[Any] = []
try:
from transformers import AutoModelForImageTextToText
extra_model_classes.append(AutoModelForImageTextToText)
except ImportError:
pass
try:
from transformers import AutoModelForVision2Seq
extra_model_classes.append(AutoModelForVision2Seq)
except ImportError:
pass
return {
"torch": torch,
"Image": Image,
"AutoProcessor": AutoProcessor,
"AutoTokenizer": AutoTokenizer,
"model_classes": [AutoModelForCausalLM, *extra_model_classes],
}
def _load_model_and_processor(args: argparse.Namespace) -> tuple[Any, Any]:
    """Load the checkpoint named by ``args.model_id`` together with its processor.

    Args:
        args: Parsed CLI namespace; ``model_id``, ``trust_remote_code`` and
            ``bf16`` are read.

    Returns:
        ``(model, processor)``. The model has been switched to eval mode. The
        processor is an ``AutoProcessor`` instance, or an ``AutoTokenizer``
        instance when loading the processor raised.

    Raises:
        RuntimeError: If none of the candidate auto-model classes could load
            the checkpoint; the message lists each class with its error.
    """
    deps = _load_eval_stack()
    hub_kwargs = {"trust_remote_code": args.trust_remote_code}

    # Best effort: any failure to build a processor falls back to a tokenizer.
    try:
        processor = deps["AutoProcessor"].from_pretrained(args.model_id, **hub_kwargs)
    except Exception:
        processor = deps["AutoTokenizer"].from_pretrained(args.model_id, **hub_kwargs)

    # ``None`` leaves the dtype choice to ``from_pretrained``.
    dtype = deps["torch"].bfloat16 if args.bf16 else None

    failures = []
    for candidate in deps["model_classes"]:
        try:
            model = candidate.from_pretrained(
                args.model_id,
                torch_dtype=dtype,
                device_map="auto",
                **hub_kwargs,
            )
        except Exception as exc:  # pragma: no cover
            failures.append(f"{candidate.__name__}: {exc}")
        else:
            # First class that loads the checkpoint wins.
            break
    else:  # pragma: no cover
        raise RuntimeError(
            "Failed to load model. Tried: " + " | ".join(failures)
        )

    model.eval()
    return model, processor
def _move_inputs_to_device(inputs: Dict[str, Any], device: Any) -> Dict[str, Any]:
moved = {}
for key, value in inputs.items():
if hasattr(value, "to"):
moved[key] = value.to(device)
else:
moved[key] = value
return moved
def _existing_frame_paths(observation: Any) -> List[str]:
return [path for path in observation.frame_paths if Path(path).exists()]
def _build_multimodal_messages(
    observation: Any,
    image_paths: Sequence[str],
) -> List[Dict[str, Any]]:
    """Build the two-message chat (system + user) for one observation.

    Args:
        observation: Observation handed to ``_build_user_prompt``.
        image_paths: Frame paths to attach; when empty the user message is
            plain text.

    Returns:
        ``[system_message, user_message]``. With images, the user content is a
        list of ``{"type": "image", ...}`` parts followed by a single
        ``{"type": "text", ...}`` part; without images it is the prompt string.
    """
    prompt = _build_user_prompt(observation)
    system_message = {"role": "system", "content": SYSTEM_PROMPT}

    if image_paths:
        parts: List[Dict[str, str]] = [
            {"type": "image", "image": str(frame)} for frame in image_paths
        ]
        # The text instruction goes after all image parts.
        parts.append({"type": "text", "text": prompt})
        user_message = {"role": "user", "content": parts}
    else:
        user_message = {"role": "user", "content": prompt}

    return [system_message, user_message]
def _apply_chat_template(processor: Any, messages: List[Dict[str, Any]]) -> str:
if hasattr(processor, "apply_chat_template"):
return processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
tokenizer = getattr(processor, "tokenizer", processor)
if hasattr(tokenizer, "apply_chat_template"):
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
raise ValueError("Processor/tokenizer does not support apply_chat_template.")
def _load_images(image_paths: Sequence[str]) -> List[Any]:
    """Open every path with PIL and return RGB copies, in input order.

    Each file is opened in a ``with`` block, so the returned images are
    copies that no longer depend on the opened file object.
    """
    pil_image = _load_eval_stack()["Image"]

    def _read_rgb(path: str) -> Any:
        with pil_image.open(path) as handle:
            return handle.convert("RGB").copy()

    return [_read_rgb(path) for path in image_paths]
def _generate_completion(
    model: Any,
    processor: Any,
    observation: Any,
    *,
    max_completion_length: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str]:
    """Generate one completion for ``observation`` and report the input mode used.

    Frames listed on the observation that exist on disk are fed to the model
    as images when the processor has an ``image_processor``; otherwise the
    prompt is text-only.

    Args:
        model: Loaded model exposing ``parameters()`` and ``generate()``.
        processor: Processor or tokenizer matching the model.
        observation: Current environment observation.
        max_completion_length: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values ``<= 0`` select greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(completion_text, input_mode)`` where ``input_mode`` is ``"image"``
        or ``"text"``.
    """
    stack = _load_eval_stack()
    torch = stack["torch"]
    device = next(model.parameters()).device

    image_paths = _existing_frame_paths(observation)
    use_images = bool(image_paths) and hasattr(processor, "image_processor")

    messages = _build_multimodal_messages(
        observation,
        image_paths if use_images else [],
    )
    prompt_text = _apply_chat_template(processor, messages)

    if use_images:
        inputs = processor(
            text=[prompt_text],
            images=_load_images(image_paths),
            return_tensors="pt",
            padding=True,
        )
        inputs = _move_inputs_to_device(inputs, device)
        input_mode = "image"
    else:
        inputs = _tokenize_prompt(processor, prompt_text, device)
        input_mode = "text"

    # Prompt length, used below to strip the prompt tokens from the output.
    input_length = int(inputs["input_ids"].shape[1])

    tokenizer = _tokenizer_from_processor(processor)
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Tokenizers without a pad token: pad with EOS, which is what
        # ``generate`` would fall back to anyway (with a warning).
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_completion_length,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0.0:
        generation_kwargs.update(
            do_sample=True,
            temperature=max(temperature, 1e-5),
            top_p=top_p,
        )
    else:
        # Greedy decoding ignores temperature/top_p; forwarding them only
        # makes transformers warn about unused sampling flags.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        sequences = model.generate(**inputs, **generation_kwargs)

    completion_ids = sequences[0][input_length:]
    return _batch_decode(processor, completion_ids.tolist()), input_mode
def _evaluate_scene(
    model: Any,
    processor: Any,
    env: EgosocialEnvironment,
    *,
    scene_id: str,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Roll out one scene until the environment reports ``done`` and collect results.

    Args:
        model: Loaded model used for generation.
        processor: Processor or tokenizer matching the model.
        env: Environment instance; it is reset to ``scene_id`` in
            ``args.env_mode``.
        scene_id: Scene to evaluate.
        args: Parsed CLI namespace; ``env_mode``, ``max_completion_length``,
            ``temperature`` and ``top_p`` are read.

    Returns:
        A dict with ``scene_id``, ``reward``, ``correct``,
        ``transition_source``, ``reward_breakdown``, ``rubric_average`` and
        ``turns`` (one record per generated action), taken from the final
        observation and its metadata.
    """
    generation_options = {
        "max_completion_length": args.max_completion_length,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }

    turns: List[Dict[str, Any]] = []
    observation = env.reset(scene_id=scene_id, mode=args.env_mode)
    while not observation.done:
        completion_text, input_mode = _generate_completion(
            model,
            processor,
            observation,
            **generation_options,
        )
        action, format_score = _parse_action(completion_text, observation)
        turn_record = {
            "phase": observation.phase,
            "raw_completion": completion_text,
            "action": action.model_dump(),
            "format_reward": format_score,
            "input_mode": input_mode,
            "frame_paths": list(observation.frame_paths),
        }
        turns.append(turn_record)
        observation = env.step(action)

    # Everything below is read off the terminal observation.
    metadata = observation.metadata or {}
    fallback_source = observation.transition_source
    return {
        "scene_id": scene_id,
        "reward": _as_float(observation.reward),
        "correct": bool(metadata.get("correct")),
        "transition_source": metadata.get("transition_source", fallback_source),
        "reward_breakdown": metadata.get("reward_breakdown", {}),
        "rubric_average": _as_float(metadata.get("rubric_average")),
        "turns": turns,
    }
def _mean(values: Iterable[float]) -> float:
numbers = list(values)
if not numbers:
return 0.0
return round(sum(numbers) / len(numbers), 4)
def _safe_output_path(path_str: str) -> Path:
path = Path(path_str)
if not path.is_absolute():
path = Path.cwd() / path
path.parent.mkdir(parents=True, exist_ok=True)
return path
def _resolved_path_str(path_str: str) -> str:
path = Path(path_str)
if not path.is_absolute():
path = Path.cwd() / path
return str(path.resolve())
def _take_scene_ids(args: argparse.Namespace) -> List[str]:
    """Return the sorted scene ids of the split, optionally truncated.

    Args:
        args: Parsed CLI namespace; ``split_path`` and ``max_samples`` are read.

    Returns:
        All scene ids in sorted order, or only the first ``max_samples`` of
        them when that value is positive.
    """
    ordered = sorted(_load_scene_id_set(args.split_path))
    limit = args.max_samples
    # Zero or negative limits mean "no limit".
    return ordered[:limit] if limit > 0 else ordered
def _summary(results: Sequence[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]:
    """Aggregate per-scene results into the run-level summary.

    Args:
        results: Per-scene result dicts; each is read for ``reward``,
            ``correct``, ``reward_breakdown`` and ``rubric_average``.
        args: Parsed CLI namespace; ``model_id``, ``env_mode``,
            ``world_model_provider`` and ``split_path`` are echoed into the
            summary.

    Returns:
        A dict with the run configuration, ``num_samples`` and the ``_mean``
        of the reward, correctness, each reward-breakdown component and the
        rubric average.
    """
    report: Dict[str, Any] = {
        "model_id": args.model_id,
        "env_mode": args.env_mode,
        "world_model_provider": args.world_model_provider,
        "split_path": _resolved_path_str(args.split_path),
        "num_samples": len(results),
        "avg_reward": _mean(entry["reward"] for entry in results),
        "accuracy": _mean(1.0 if entry["correct"] else 0.0 for entry in results),
    }

    # One averaged column per reward-breakdown component, in output order.
    breakdown_components = (
        "action_selection",
        "sensibility",
        "taxonomy_match",
        "justification_alignment",
    )
    for component in breakdown_components:
        report[f"avg_{component}"] = _mean(
            _as_float(entry["reward_breakdown"].get(component))
            for entry in results
        )

    report["avg_rubric_average"] = _mean(entry["rubric_average"] for entry in results)
    return report
def main() -> None:
    """Run the evaluation end to end: load, evaluate every scene, write the report.

    Scene ids from the split that the local environment does not know are
    skipped with a warning and recorded in the output file.

    Raises:
        ValueError: If the split yields no scene ids, or none of them exist
            in the local environment.
    """
    args = parse_args()
    model, processor = _load_model_and_processor(args)
    env = EgosocialEnvironment(world_model_provider=args.world_model_provider)

    requested_scene_ids = _take_scene_ids(args)
    if not requested_scene_ids:
        raise ValueError("No scene ids found in the evaluation split.")

    known_scene_ids = {episode["scene_id"] for episode in env._episodes}  # noqa: SLF001

    # Single pass that keeps the requested order in both output lists.
    scene_ids: List[str] = []
    missing_scene_ids: List[str] = []
    for scene_id in requested_scene_ids:
        bucket = scene_ids if scene_id in known_scene_ids else missing_scene_ids
        bucket.append(scene_id)

    if not scene_ids:
        raise ValueError("No evaluation scenes overlap with the local EgoNormia env.")

    if missing_scene_ids:
        print(
            "Warning: skipping "
            f"{len(missing_scene_ids)} scene ids that are not present in the local env."
        )
        # Only the first ten missing ids are printed.
        preview = ", ".join(missing_scene_ids[:10])
        print(f"Missing scene ids: {preview}")

    total = len(scene_ids)
    results = []
    for index, scene_id in enumerate(scene_ids, start=1):
        result = _evaluate_scene(
            model,
            processor,
            env,
            scene_id=scene_id,
            args=args,
        )
        results.append(result)
        print(
            f"[{index}/{total}] {scene_id} "
            f"reward={result['reward']:.3f} correct={int(result['correct'])}"
        )

    summary = {
        **_summary(results, args),
        "requested_num_samples": len(requested_scene_ids),
        "available_num_samples": total,
        "missing_num_samples": len(missing_scene_ids),
    }
    payload = {
        "summary": summary,
        "missing_scene_ids": missing_scene_ids,
        "results": results,
    }

    output_path = _safe_output_path(args.output_path)
    output_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    print(json.dumps(payload["summary"], indent=2))
    print(f"Saved evaluation results to {output_path}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()