| | from typing import Dict, Any |
| | import torch |
| | from lerobot.common.policies.policy import Pi0Policy |
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around a pretrained Pi0 policy."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from the specified directory.

        Args:
            model_dir: Path to the pretrained Pi0 checkpoint directory.
            **kwargs: Accepted for endpoint-framework compatibility; unused.
        """
        # Remember the device so request tensors can be moved onto it in
        # __call__; otherwise CPU-built inputs would hit a CUDA-resident
        # model and raise a device-mismatch error.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an incoming request and run inference.

        Expects ``data["inputs"]`` with keys "images" and "state" (nested
        lists of floats) and "prompt" (a string) — assumed schema, confirm
        against the caller. Returns ``{"actions": <nested list>}``.
        """
        inputs = data["inputs"]

        # Add a leading batch dimension and move tensors to the model device.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0).to(self.device)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0).to(self.device)
        prompt = [inputs["prompt"]]

        batch = {
            "observation.images": images,
            "observation.state": state,
            "prompt": prompt,
        }

        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            actions = self.policy.select_action(batch)

        # The policy output is batched; return the single request's actions
        # as plain Python lists so the response is JSON-serializable.
        return {"actions": actions[0].tolist()}
| | from typing import Dict, Any |
| | import torch |
| | from lerobot.common.policies.policy import Pi0Policy |
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around a pretrained Pi0 policy."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from the specified directory.

        Args:
            model_dir: Path to the pretrained Pi0 checkpoint directory.
            **kwargs: Accepted for endpoint-framework compatibility; unused.
        """
        # Keep the resolved device so __call__ can place request tensors on
        # it — without this, CPU inputs fed to a CUDA model would crash.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an incoming request and run inference.

        Expects ``data["inputs"]`` with keys "images" and "state" (nested
        lists of floats) and "prompt" (a string) — assumed schema, confirm
        against the caller. Returns ``{"actions": <nested list>}``.
        """
        inputs = data["inputs"]

        # Batch dimension first, then move to the same device as the model.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0).to(self.device)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0).to(self.device)
        prompt = [inputs["prompt"]]

        batch = {
            "observation.images": images,
            "observation.state": state,
            "prompt": prompt,
        }

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            actions = self.policy.select_action(batch)

        # Unbatch and convert to plain lists for a JSON-serializable reply.
        return {"actions": actions[0].tolist()}
| |
|