from typing import Dict, Any

import torch

from lerobot.common.policies.policy import Pi0Policy


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from the specified directory."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle incoming requests and run inference."""
        # Extract inputs from the request and add a batch dimension.
        images = torch.tensor(data["inputs"]["images"], dtype=torch.float32).unsqueeze(0)
        state = torch.tensor(data["inputs"]["state"], dtype=torch.float32).unsqueeze(0)
        prompt = [data["inputs"]["prompt"]]

        # Prepare the batch for inference, moving tensors to the policy's device.
        batch = {
            "observation.images": images.to(self.device),
            "observation.state": state.to(self.device),
            "prompt": prompt,
        }

        # Run inference without tracking gradients.
        with torch.no_grad():
            actions = self.policy.select_action(batch)

        # Return the actions for the first (and only) batch item,
        # moved back to CPU so they can be serialized.
        return {"actions": actions[0].cpu().tolist()}
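
# ---------------------------------------------------------------------------
# Minimal local smoke test (sketch). This is not part of the handler itself:
# it shows one way the handler could be exercised before deployment. The
# "pi0_model" directory, the image shape (3, 224, 224), and the 7-element
# state vector are illustrative assumptions, not values prescribed by the
# model; adjust them to match your checkpoint and robot setup.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="pi0_model")  # hypothetical local path

    # Dummy request mimicking the JSON body an inference endpoint would receive.
    request = {
        "inputs": {
            "images": [[[0.0] * 224] * 224] * 3,  # assumed CHW float image
            "state": [0.0] * 7,                   # assumed 7-DoF proprioceptive state
            "prompt": "pick up the red block",
        }
    }

    response = handler(request)
    print("action length:", len(response["actions"]))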