# openvla-micro / inference_cpu.py
# Author: theguy21
# Commit bd89217 (verified): "Add CPU inference script, update README with model details and perf stats"
# (Hugging Face file-view residue: Raw · History · Blame · Contribute · Delete · 2.4 kB)
"""
Edge device / CPU inference for OpenVLA-Micro.
This script is optimized for resource-constrained environments.
Two modes:
1. Standard CPU – float32, ~3-5 sec/step on modern x86, ~6GB RAM
2. Low-RAM (4-bit) – uses bitsandbytes 4-bit quantization, ~2.5GB RAM,
slightly slower but usable on 4GB devices like RPi 5
with sufficient swap.
Usage:
python inference_cpu.py --image demo.jpg "pick up the red block"
python inference_cpu.py --low-ram --image demo.jpg "pick up the red block"
python inference_cpu.py --checkpoint ./openvla-micro-distill.pt --image demo.jpg "pick up the red block"
"""
import argparse
from PIL import Image
from modeling_openvla_micro import OpenVLAMicro
def _build_llm_kwargs(low_ram: bool) -> dict:
    """Return the ``llm_kwargs`` forwarded to ``OpenVLAMicro.from_pretrained``.

    Also prints a one-line banner describing the selected mode.

    Args:
        low_ram: If True, request bitsandbytes 4-bit quantization with
            float32 compute; otherwise request plain float32 weights.

    Returns:
        A fresh dict of loader keyword arguments.
    """
    if low_ram:
        print("Low-RAM mode: 4-bit quantization (requires bitsandbytes)")
        return {
            "load_in_4bit": True,
            # Dtype is given as a string; presumably the loader resolves it to a
            # torch dtype — TODO confirm against modeling_openvla_micro.
            "bnb_4bit_compute_dtype": "float32",
            "bnb_4bit_use_double_quant": True,
        }
    print("Standard CPU mode: float32 (~6GB RAM)")
    return {"torch_dtype": "float32"}


def main():
    """Parse CLI arguments, load OpenVLA-Micro on CPU and print one predicted action.

    The input image is decoded *before* the model is loaded. An unreadable
    ``--image`` path therefore fails immediately with a usage error (exit
    code 2), instead of after a slow, multi-GB model load.
    """
    parser = argparse.ArgumentParser(description="OpenVLA-Micro CPU/edge inference")
    parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
                        help="HF repo ID or local .pt path")
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--low-ram", action="store_true",
                        help="4-bit quantized LLM (~2.5GB peak, requires bitsandbytes)")
    parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
                        help="Task instruction (positional, optional)")
    args = parser.parse_args()

    device = "cpu"

    # Fail fast: decode the image before the expensive model load. The `with`
    # block closes the file handle; convert() returns an independent, loaded copy.
    try:
        with Image.open(args.image) as img:
            image = img.convert("RGB")
    except OSError as err:
        # OSError covers FileNotFoundError, PermissionError and Pillow's
        # UnidentifiedImageError (undecodable file).
        parser.error(f"could not read image {args.image!r}: {err}")

    llm_kwargs = _build_llm_kwargs(args.low_ram)
    print(f"Loading OpenVLA-Micro from {args.checkpoint} on CPU...")
    model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
    model.eval()

    # NOTE(review): in --low-ram mode bitsandbytes likely stores 4-bit weights
    # packed (two values per byte), so numel() would under-report the true
    # parameter count for quantized layers — confirm before quoting this number.
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"Model loaded ({n_params:.0f}M params)")

    print(f"Image: {image.size}")
    print(f"Instruction: {args.instruction}")
    action = model.predict_action(image, args.instruction)
    print(f"Action (7-DoF): {action}")
if __name__ == "__main__":
main()