File size: 1,924 Bytes
dd9b4af
 
 
bd89217
 
 
 
 
 
 
 
dd9b4af
 
 
 
 
 
 
 
bd89217
 
 
 
 
 
 
dd9b4af
 
bd89217
 
 
 
 
 
 
 
 
 
 
dd9b4af
bd89217
 
dd9b4af
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
Standalone inference script for OpenVLA-Micro.
Usage:
    # GPU inference (HF hub)
    python inference.py --image demo.jpg "pick up the red block"

    # From a local .pt file
    python inference.py --checkpoint openvla-micro-distill.pt --image demo.jpg "pick up the red block"

    # CPU inference
    python inference.py --device cpu --image demo.jpg "pick up the red block"
"""
import argparse
from PIL import Image
from modeling_openvla_micro import OpenVLAMicro


def main():
    """Parse CLI arguments, load OpenVLA-Micro, and print the predicted action.

    The ``--image`` path is validated and the image decoded *before* the model
    is loaded, so a bad input fails immediately instead of after a potentially
    slow checkpoint download/load.

    Exits with status 2 (via ``argparse``) when ``--image`` does not point to
    an existing file.
    """
    import os  # local import: only needed for the up-front path check

    parser = argparse.ArgumentParser(description="OpenVLA-Micro inference")
    parser.add_argument("--checkpoint", type=str, default="theguy21/openvla-micro",
                        help="HF repo ID or path to local .pt checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device: auto, cuda, or cpu")
    parser.add_argument("instruction", type=str, nargs="?", default="pick up the red block",
                        help="Task instruction (positional, optional)")
    args = parser.parse_args()

    # Fail fast on a bad path: report it as a normal CLI usage error rather
    # than a traceback raised after the model has already been loaded.
    if not os.path.isfile(args.image):
        parser.error(f"image file not found: {args.image}")

    # Decode the image before the expensive model load so a corrupt or
    # unsupported file also fails early. The context manager closes the
    # underlying file handle; convert() returns a new, independent image.
    with Image.open(args.image) as img:
        image = img.convert("RGB")

    device = args.device
    if device == "auto":
        import torch  # deferred: only needed here to probe CUDA availability
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Extra kwargs forwarded to OpenVLAMicro.from_pretrained; on CPU we ask
    # for float32 (presumably applied to the language-model weights — confirm
    # against modeling_openvla_micro).
    llm_kwargs = {}
    if device == "cpu":
        llm_kwargs["torch_dtype"] = "float32"

    print(f"Loading OpenVLA-Micro from {args.checkpoint} on {device}...")
    model = OpenVLAMicro.from_pretrained(args.checkpoint, device=device, llm_kwargs=llm_kwargs)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"Model loaded ({n_params:.0f}M params)")

    print(f"Image: {image.size}")
    print(f"Instruction: {args.instruction}")
    action = model.predict_action(image, args.instruction)
    print(f"Action (7-DoF): {action}")


if __name__ == "__main__":
    main()