"""
deploy.py

Starts a VLA server that clients can query to get robot actions.
"""
|
|
| import os.path |
|
|
| |
| import json_numpy |
|
|
| json_numpy.patch() |
| import json |
| import logging |
| import numpy as np |
| import traceback |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Union |
|
|
| import draccus |
| import torch |
| import uvicorn |
| from fastapi import FastAPI |
| from fastapi.responses import JSONResponse |
| from PIL import Image |
| from transformers import AutoModelForVision2Seq, AutoProcessor |
|
|
| from experiments.robot.openvla_utils import ( |
| get_vla, |
| get_vla_action, |
| get_action_head, |
| get_processor, |
| get_proprio_projector, |
| ) |
| from experiments.robot.robot_utils import ( |
| get_image_resize_size, |
| ) |
| from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX |
|
|
|
|
def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str:
    """Format a task instruction into the OpenVLA prompt template.

    Note: `openvla_path` is accepted for interface compatibility but is not
    consulted when building the prompt.
    """
    lowered = instruction.lower()
    return "In: What action should the robot take to " + lowered + "?\nOut:"
|
|
|
|
| |
class OpenVLAServer:
    """A simple server for OpenVLA models; exposes `/act` to predict an action for a
    given observation + instruction.
    """

    def __init__(self, cfg) -> None:
        # Fix: the original annotated `__init__ -> Path`; constructors return None.
        self.cfg = cfg

        # Load the base VLA policy.
        self.vla = get_vla(cfg)

        # Optional projector mapping proprioceptive state into the LLM embedding space.
        self.proprio_projector = None
        if cfg.use_proprio:
            self.proprio_projector = get_proprio_projector(cfg, self.vla.llm_dim, PROPRIO_DIM)

        # Optional continuous action head (L1 regression or diffusion decoding).
        self.action_head = None
        if cfg.use_l1_regression or cfg.use_diffusion:
            self.action_head = get_action_head(cfg, self.vla.llm_dim)

        # Fail fast if the requested action un-normalization statistics are missing.
        assert cfg.unnorm_key in self.vla.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"

        # Input processor and target image size for observation preprocessing.
        # (Dropped the redundant `self.processor = None` that was immediately overwritten.)
        self.processor = get_processor(cfg)
        self.resize_size = get_image_resize_size(cfg)

    def get_server_action(self, payload: Dict[str, Any]) -> "Union[JSONResponse, str]":
        """Predict an action for one observation + instruction request.

        `payload` is either the observation dict itself (must contain an
        "instruction" key) or `{"encoded": <json string>}` for double-encoded
        requests. Returns a JSONResponse with the action on success, or the
        string "error" on failure (fix: original annotated `-> str` only).
        """
        try:
            # Clients may double-encode the payload as a JSON string under "encoded".
            if double_encode := "encoded" in payload:
                assert len(payload.keys()) == 1, "Only uses encoded payload!"
                payload = json.loads(payload["encoded"])

            observation = payload
            instruction = observation["instruction"]

            # Run VLA inference to get the predicted action (chunk).
            action = get_vla_action(
                self.cfg,
                self.vla,
                self.processor,
                observation,
                instruction,
                action_head=self.action_head,
                proprio_projector=self.proprio_projector,
                use_film=self.cfg.use_film,
            )

            if double_encode:
                return JSONResponse(json_numpy.dumps(action))
            else:
                return JSONResponse(action)
        except Exception:
            # Fix: replaced bare `except:` — it also swallowed KeyboardInterrupt and
            # SystemExit; shutdown signals now propagate while request errors are logged.
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'observation': dict, 'instruction': str}\n"
            )
            return "error"

    def run(self, host: str = "0.0.0.0", port: int = 8777) -> None:
        """Register the `/act` route on a fresh FastAPI app and serve with uvicorn (blocking)."""
        self.app = FastAPI()
        self.app.post("/act")(self.get_server_action)
        uvicorn.run(self.app, host=host, port=port)
|
|
|
|
@dataclass
class DeployConfig:
    """CLI-configurable settings for serving an OpenVLA policy (parsed by draccus)."""

    # ---- Server networking -------------------------------------------------
    host: str = "0.0.0.0"
    port: int = 8777

    # ---- Model / checkpoint ------------------------------------------------
    model_family: str = "openvla"
    # Pretrained checkpoint to serve — presumably a local path or hub id; confirm with get_vla()
    pretrained_checkpoint: Union[str, Path] = ""

    # ---- Action decoding ---------------------------------------------------
    use_l1_regression: bool = True            # continuous action head trained with L1 regression
    use_diffusion: bool = False               # alternatively, a diffusion-based action head
    num_diffusion_steps_train: int = 50       # diffusion steps at training time
    num_diffusion_steps_inference: int = 50   # diffusion steps at inference time
    use_film: bool = False                    # passed through to get_vla_action (FiLM conditioning, presumably; confirm)
    num_images_in_input: int = 3              # number of camera images per observation
    use_proprio: bool = True                  # include proprioceptive state (enables the proprio projector)

    # ---- Image preprocessing -----------------------------------------------
    center_crop: bool = True

    # LoRA rank — presumably must match how the checkpoint was fine-tuned; confirm with get_vla()
    lora_rank: int = 32

    # Key into the VLA's `norm_stats` used to un-normalize predicted actions
    # (validated in OpenVLAServer.__init__)
    unnorm_key: Union[str, Path] = ""
    use_relative_actions: bool = False

    # ---- Quantized inference -----------------------------------------------
    load_in_8bit: bool = False
    load_in_4bit: bool = False

    # Random seed
    seed: int = 7
| |
|
|
|
|
@draccus.wrap()
def deploy(cfg: DeployConfig) -> None:
    """Build an OpenVLA server from the parsed config and serve until interrupted."""
    OpenVLAServer(cfg).run(cfg.host, port=cfg.port)
|
|
|
|
if __name__ == "__main__":
    # Entry point: draccus parses CLI flags into DeployConfig, then the server starts.
    deploy()
|
|