from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy
class EndpointHandler:
    """Serve a pretrained Pi0 policy behind a simple request/response interface."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 policy weights from *model_dir* and switch to eval mode."""
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=target_device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run a single inference step for one request payload.

        The payload under ``data["inputs"]`` is expected to carry ``images``,
        ``state`` and ``prompt`` entries — assumed from the lookups below;
        confirm against the caller.
        """
        payload = data["inputs"]
        # Each observation gets a leading batch dimension of size 1.
        obs_images = torch.tensor(payload["images"], dtype=torch.float32).unsqueeze(0)
        obs_state = torch.tensor(payload["state"], dtype=torch.float32).unsqueeze(0)
        batch = {
            "observation.images": obs_images,
            "observation.state": obs_state,
            "prompt": [payload["prompt"]],
        }
        # No gradients are needed at serving time.
        with torch.no_grad():
            predicted = self.policy.select_action(batch)
        # Unbatch: callers receive the first item as plain Python lists.
        return {"actions": predicted[0].tolist()}
from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy
# NOTE(review): this class is an exact duplicate of the EndpointHandler
# defined earlier in this file; at import time this second definition
# shadows the first. The duplicate should likely be removed entirely.
class EndpointHandler:
    """Load a Pi0 policy and serve inference requests."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from the specified directory onto GPU if available."""
        self.policy = Pi0Policy.from_pretrained(
            model_dir, device="cuda" if torch.cuda.is_available() else "cpu"
        )
        # Inference-only: disable dropout/batch-norm training behavior.
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle an incoming request and run one inference step.

        Expects ``data["inputs"]`` to contain ``images``, ``state`` and
        ``prompt`` — assumed from the lookups below; confirm with the caller.
        """
        inputs = data["inputs"]
        # Extract request inputs and prepend a batch dimension of size 1.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0)
        batch = {
            "observation.images": images,
            "observation.state": state,
            "prompt": [inputs["prompt"]],
        }
        # Run inference without tracking gradients.
        with torch.no_grad():
            actions = self.policy.select_action(batch)
        # Return the first (only) batch item as plain Python lists.
        return {"actions": actions[0].tolist()}