| """Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.""" |
|
|
| import filecmp |
| import json |
| import os |
| import shutil |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import json_numpy |
| import numpy as np |
| import tensorflow as tf |
| import torch |
| from PIL import Image |
| from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor |
|
|
| |
# NOTE(review): json_numpy.patch() presumably lets the standard `json` module
# (de)serialize NumPy arrays; confirm against the json_numpy documentation.
json_numpy.patch()

# Print floats in NumPy arrays with three decimal places to keep logs compact.
# The bound `str.format` method produces exactly "{0:0.3f}".format(x) for each value.
np.set_printoptions(formatter={"float": "{0:0.3f}".format})
|
|
|
|
def update_auto_map(pretrained_checkpoint: str) -> None:
    """
    Point the checkpoint's config.json ``auto_map`` at the OpenVLA classes.

    Loads ``config.json`` from the checkpoint directory and replaces its whole
    ``auto_map`` entry with a mapping that routes ``AutoConfig`` to
    ``configuration_prismatic.OpenVLAConfig`` and ``AutoModelForVision2Seq`` to
    ``modeling_prismatic.OpenVLAForActionPrediction``. Any other ``auto_map`` keys
    are dropped; all other top-level config keys are preserved.

    Before modifying the file, a timestamped backup (``config.json.back.<timestamp>``)
    is written next to it. If ``auto_map`` already has exactly the target value,
    the file is left untouched and no backup is created. This keeps repeated calls
    from piling up backups, or overwriting an earlier backup made in the same second.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. If this is not an
            existing local directory, the function returns without doing anything.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    config_path = os.path.join(pretrained_checkpoint, "config.json")
    if not os.path.exists(config_path):
        print(f"Warning: No config.json found at {config_path}")
        return

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    target_auto_map = {
        "AutoConfig": "configuration_prismatic.OpenVLAConfig",
        "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction",
    }

    # Nothing to change: skip both the backup and the rewrite.
    if config.get("auto_map") == target_auto_map:
        print(f"config.json at {os.path.abspath(config_path)} already uses the OpenVLA auto_map; no changes made.")
        return

    # Keep a timestamped copy of the original before overwriting it.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}")
    shutil.copy2(config_path, backup_path)
    print(f"Created backup of original config at: {os.path.abspath(backup_path)}")

    config["auto_map"] = target_auto_map

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"Updated config.json at: {os.path.abspath(config_path)}")
    print("Changes made:")
    print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"')
    print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"')
|
|
|
|
def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
    """
    Report whether two files have byte-for-byte identical contents.

    Files whose sizes differ are rejected immediately. Otherwise the contents are
    compared with ``filecmp.cmp(..., shallow=False)``.

    Args:
        path1: Path to the first file
        path2: Path to the second file

    Returns:
        bool: True if the files have identical contents, False otherwise
    """
    first, second = Path(path1), Path(path2)

    # Cheap pre-check: files of different lengths can never match.
    sizes_match = first.stat().st_size == second.stat().st_size

    return sizes_match and filecmp.cmp(first, second, shallow=False)
|
|
|
|
| def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None: |
| """ |
| Handle syncing of files between current directory and checkpoint. |
| |
| Creates backups if files exist but differ, and copies current versions to checkpoint. |
| |
| Args: |
| curr_filepath: Path to the current file version |
| checkpoint_filepath: Path where the file should be in the checkpoint |
| file_type: Description of the file type for logging |
| """ |
| if os.path.exists(checkpoint_filepath): |
| |
| match = check_identical_files(curr_filepath, checkpoint_filepath) |
|
|
| if not match: |
| print( |
| "\n------------------------------------------------------------------------------------------------\n" |
| f"Found mismatch between:\n" |
| f"Current: {curr_filepath}\n" |
| f"Checkpoint: {checkpoint_filepath}\n" |
| ) |
|
|
| |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| backup_path = f"{checkpoint_filepath}.back.{timestamp}" |
| shutil.copy2(checkpoint_filepath, backup_path) |
| print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}") |
|
|
| |
| shutil.copy2(curr_filepath, checkpoint_filepath) |
| print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}") |
| print( |
| f"Changes complete. The checkpoint will now use the current version of {file_type}" |
| "\n------------------------------------------------------------------------------------------------\n" |
| ) |
| else: |
| |
| shutil.copy2(curr_filepath, checkpoint_filepath) |
| print( |
| "\n------------------------------------------------------------------------------------------------\n" |
| f"No {file_type} found in checkpoint directory.\n" |
| f"Copied current version from: {curr_filepath}\n" |
| f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}" |
| "\n------------------------------------------------------------------------------------------------\n" |
| ) |
|
|
|
|
def check_model_logic_mismatch(pretrained_checkpoint: str, search_root: str = "./prismatic/") -> None:
    """
    Sync the checkpoint's model-logic files with the versions in the current code.

    Searches ``search_root`` for ``modeling_prismatic.py`` and
    ``configuration_prismatic.py``. Each file found is passed to ``_handle_file_sync``
    together with the matching path inside the checkpoint:
      - If the checkpoint file exists and differs: it is backed up and the current
        version is copied over it.
      - If the checkpoint file doesn't exist: the current version is copied in.
    A file that cannot be found under ``search_root`` triggers a warning and is skipped.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. If this is not an
            existing local directory, the function does nothing.
        search_root: Directory tree searched for the current versions of the files.
            Defaults to ``"./prismatic/"``, resolved relative to the current working
            directory.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    # Maps each filename to the first path found under search_root (None until found).
    curr_files: Dict[str, Optional[str]] = {"modeling_prismatic.py": None, "configuration_prismatic.py": None}

    for root, dirnames, files in os.walk(search_root):
        # os.walk yields directories in filesystem-dependent order. Sorting the
        # subdirectory list in place (top-down walk) makes the traversal, and so the
        # choice of copy when several exist, deterministic.
        dirnames.sort()
        for filename in curr_files:
            if curr_files[filename] is None and filename in files:
                curr_files[filename] = os.path.join(root, filename)
        if all(path is not None for path in curr_files.values()):
            break  # Everything located; no need to walk the rest of the tree.

    for filename, curr_filepath in curr_files.items():
        if curr_filepath is None:
            print(f"WARNING: `{filename}` is not found anywhere under `{search_root}`.")
            continue

        checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
        _handle_file_sync(curr_filepath, checkpoint_filepath, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|