# Source path: TTI/Dev/verl/utils/openvla_utils.py
# Uploaded by JosephBai ("Upload folder using huggingface_hub"), commit 857c2e9 (verified)
"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""
import filecmp
import json
import os
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import json_numpy
import numpy as np
import tensorflow as tf
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
# Apply the json_numpy patch at import time. Presumably this lets NumPy arrays
# be serialized/deserialized through JSON -- confirm against the json_numpy docs.
json_numpy.patch()
# Print NumPy floats with exactly three decimal places (e.g. 0.500).
# `np.set_printoptions` changes NumPy's global print state, so this applies to
# every array printed after this module has been imported.
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
def update_auto_map(pretrained_checkpoint: str) -> None:
    """
    Update the AutoMap configuration in the checkpoint config.json file.

    This loads the config.json file inside the checkpoint directory and overwrites
    the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes.
    A timestamped backup of the previous config.json is written next to it before
    the file is modified.

    The call is idempotent: if config.json already holds exactly the expected
    `auto_map`, nothing is written and no backup is created. This prevents backup
    files from piling up on every call and, more importantly, prevents a second
    call within the same second from overwriting the backup of the true original
    with the already-modified config (backup names only have second resolution).

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. If it is not an
            existing directory, the function returns without doing anything.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    config_path = os.path.join(pretrained_checkpoint, "config.json")
    if not os.path.exists(config_path):
        print(f"Warning: No config.json found at {config_path}")
        return

    expected_auto_map = {
        "AutoConfig": "configuration_prismatic.OpenVLAConfig",
        "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction",
    }

    # Read first so we can tell whether any change is needed at all.
    # JSON files are UTF-8; do not rely on the platform default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Already up to date: leave the file and any existing backups untouched.
    if isinstance(config, dict) and config.get("auto_map") == expected_auto_map:
        return

    # Create timestamped backup of the file as it is right now
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}")
    shutil.copy2(config_path, backup_path)
    print(f"Created backup of original config at: {os.path.abspath(backup_path)}")

    # Replace the whole auto_map entry and write the config back
    config["auto_map"] = expected_auto_map
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"Updated config.json at: {os.path.abspath(config_path)}")
    print("Changes made:")
    print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"')
    print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"')
def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
    """
    Report whether two files hold exactly the same bytes.

    The file sizes are compared first as a cheap early-out; only when they
    agree is a full (non-shallow) content comparison performed.

    Args:
        path1: Path to the first file
        path2: Path to the second file

    Returns:
        bool: True if files are identical, False otherwise
    """
    first = Path(path1)
    second = Path(path2)

    # Cheap pre-check: files of different size cannot be identical
    sizes_match = first.stat().st_size == second.stat().st_size

    # Byte-by-byte comparison only runs when the sizes agree
    return sizes_match and filecmp.cmp(first, second, shallow=False)
def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None:
    """
    Make the checkpoint's copy of a file match the current version.

    Three cases are handled:
    - The checkpoint has no such file: the current file is copied in.
    - The checkpoint file is identical to the current one: nothing happens.
    - The checkpoint file differs: it is backed up under a timestamped name and
      then overwritten with the current file.

    Args:
        curr_filepath: Path to the current file version
        checkpoint_filepath: Path where the file should be in the checkpoint
        file_type: Description of the file type for logging
    """
    banner = "\n------------------------------------------------------------------------------------------------\n"

    # Case 1: nothing in the checkpoint yet -- just copy the current version over
    if not os.path.exists(checkpoint_filepath):
        shutil.copy2(curr_filepath, checkpoint_filepath)
        print(
            f"{banner}"
            f"No {file_type} found in checkpoint directory.\n"
            f"Copied current version from: {curr_filepath}\n"
            f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}"
            f"{banner}"
        )
        return

    # Case 2: both copies already agree -- nothing to do
    if check_identical_files(curr_filepath, checkpoint_filepath):
        return

    # Case 3: the copies differ -- report, back up the checkpoint copy, then replace it
    print(f"{banner}Found mismatch between:\nCurrent: {curr_filepath}\nCheckpoint: {checkpoint_filepath}\n")

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = f"{checkpoint_filepath}.back.{stamp}"
    shutil.copy2(checkpoint_filepath, backup_path)
    print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}")

    shutil.copy2(curr_filepath, checkpoint_filepath)
    print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}")
    print(f"Changes complete. The checkpoint will now use the current version of {file_type}{banner}")
def check_model_logic_mismatch(pretrained_checkpoint: str, search_root: str = "./prismatic/") -> None:
    """
    Check and sync model logic files between current code and checkpoint.

    Handles the relationship between current and checkpoint versions of both
    modeling_prismatic.py and configuration_prismatic.py:
    - If checkpoint file exists and differs: creates backup and copies current version
    - If checkpoint file doesn't exist: copies current version

    The current versions are located by walking `search_root`; for each filename the
    first match encountered by the walk is used. A file that cannot be found under
    `search_root` only produces a warning and is skipped.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. If it is not an
            existing directory, the function returns without doing anything.
        search_root: Directory tree searched for the current model logic files.
            Defaults to "./prismatic/", which is resolved relative to the current
            working directory.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    # Find current files
    curr_files: Dict[str, Optional[str]] = {
        "modeling_prismatic.py": None,
        "configuration_prismatic.py": None,
    }
    for root, _, files in os.walk(search_root):
        for filename, found_path in curr_files.items():
            # Keep the first hit for each filename; later duplicates are ignored
            if found_path is None and filename in files:
                curr_files[filename] = os.path.join(root, filename)
        # No need to keep walking once every file has been located
        if all(path is not None for path in curr_files.values()):
            break

    # Check and handle each file
    for filename, curr_filepath in curr_files.items():
        if curr_filepath is None:
            # Name the directory that was actually searched, not "the current directory"
            print(f"WARNING: `{filename}` is not found anywhere under `{search_root}`.")
            continue
        checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
        _handle_file_sync(curr_filepath, checkpoint_filepath, filename)