# servertest_pi0 / handler.py
# Custom inference handler for a Pi0 policy checkpoint.
# (Uploaded by joaoocruz00, commit d7aadf2.)
from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy
class EndpointHandler:
    """Inference endpoint wrapper around a Pi0 policy checkpoint."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from the specified directory.

        Args:
            model_dir: Path to the pretrained checkpoint directory.
            **kwargs: Extra arguments passed by the serving runtime (ignored).
        """
        # Remember the device so request tensors can be moved onto it in
        # __call__ — the model may live on CUDA while inputs arrive as CPU data.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()  # inference mode: freeze dropout/batch-norm behavior

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle an incoming request and run inference.

        Expects ``data["inputs"]`` with keys "images", "state", and "prompt".
        Returns ``{"actions": <list>}`` for the single request item.
        """
        inputs = data["inputs"]
        # Add a batch dimension and move each tensor to the model's device.
        # BUG FIX: the original left these on CPU, which raises a device
        # mismatch whenever the policy was loaded onto CUDA.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0).to(self.device)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0).to(self.device)
        batch = {
            "observation.images": images,
            "observation.state": state,
            # NOTE(review): key names assumed by this handler — confirm they
            # match what Pi0Policy.select_action expects.
            "prompt": [inputs["prompt"]],
        }
        # No gradients needed for inference; avoids building an autograd graph.
        with torch.no_grad():
            actions = self.policy.select_action(batch)
        # First (only) batch item, converted to plain lists for JSON output.
        return {"actions": actions[0].tolist()}
from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy
class EndpointHandler:
    """Serves a Pi0 policy: loads a checkpoint once, then answers inference calls."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the pretrained Pi0 policy from ``model_dir`` and set eval mode."""
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=target)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one inference step on the payload found in ``data["inputs"]``.

        The payload must carry "images", "state", and "prompt" entries; the
        result is ``{"actions": <list>}`` for the single request item.
        """
        payload = data["inputs"]
        # Single-sample batch: prepend a batch axis to each tensor input.
        batch = {
            "observation.images": torch.tensor(payload["images"], dtype=torch.float32).unsqueeze(0),
            "observation.state": torch.tensor(payload["state"], dtype=torch.float32).unsqueeze(0),
            "prompt": [payload["prompt"]],
        }
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            predicted = self.policy.select_action(batch)
        # First (and only) batch element, converted for JSON serialization.
        return {"actions": predicted[0].tolist()}