File size: 2,284 Bytes
d7aadf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy

class EndpointHandler:
    """Inference endpoint wrapper around a pretrained Pi0 policy.

    Loads the model once at construction time and serves one request per
    call. Each request carries a single example (no batch dimension) under
    ``data["inputs"]`` with keys ``"images"``, ``"state"``, and ``"prompt"``.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from `model_dir`, preferring GPU when available.

        Args:
            model_dir: Directory containing the pretrained policy weights.
            **kwargs: Unused; accepted for handler-interface compatibility.
        """
        # Remember the device so request tensors can be moved onto it in
        # __call__; without this, a CUDA-loaded model would receive CPU
        # inputs and inference would fail with a device mismatch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request and return the predicted actions.

        Args:
            data: Request payload; ``data["inputs"]`` must provide
                ``"images"`` and ``"state"`` (nested number lists) and a
                ``"prompt"`` string.

        Returns:
            ``{"actions": ...}`` where the value is the first batch item of
            the policy output converted to nested Python lists.

        Raises:
            KeyError: If a required field is missing from the payload.
        """
        inputs = data["inputs"]
        # Add a leading batch dimension and move onto the model's device.
        # NOTE(review): assumes the payload is a single un-batched example —
        # confirm against the client contract.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0).to(self.device)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0).to(self.device)
        prompt = [inputs["prompt"]]  # policy expects a list of prompts, one per batch item

        batch = {
            "observation.images": images,
            "observation.state": state,
            "prompt": prompt,
        }

        # Inference only: no gradients needed.
        with torch.no_grad():
            actions = self.policy.select_action(batch)

        # Return the first (and only) batch item as plain Python lists so the
        # response is JSON-serializable.
        return {"actions": actions[0].tolist()}
from typing import Dict, Any
import torch
from lerobot.common.policies.policy import Pi0Policy

class EndpointHandler:
    """Inference endpoint wrapper around a pretrained Pi0 policy.

    Loads the model once at construction time and serves one request per
    call. Each request carries a single example (no batch dimension) under
    ``data["inputs"]`` with keys ``"images"``, ``"state"``, and ``"prompt"``.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the Pi0 model from `model_dir`, preferring GPU when available.

        Args:
            model_dir: Directory containing the pretrained policy weights.
            **kwargs: Unused; accepted for handler-interface compatibility.
        """
        # Remember the device so request tensors can be moved onto it in
        # __call__; without this, a CUDA-loaded model would receive CPU
        # inputs and inference would fail with a device mismatch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.policy = Pi0Policy.from_pretrained(model_dir, device=self.device)
        self.policy.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request and return the predicted actions.

        Args:
            data: Request payload; ``data["inputs"]`` must provide
                ``"images"`` and ``"state"`` (nested number lists) and a
                ``"prompt"`` string.

        Returns:
            ``{"actions": ...}`` where the value is the first batch item of
            the policy output converted to nested Python lists.

        Raises:
            KeyError: If a required field is missing from the payload.
        """
        inputs = data["inputs"]
        # Add a leading batch dimension and move onto the model's device.
        # NOTE(review): assumes the payload is a single un-batched example —
        # confirm against the client contract.
        images = torch.tensor(inputs["images"], dtype=torch.float32).unsqueeze(0).to(self.device)
        state = torch.tensor(inputs["state"], dtype=torch.float32).unsqueeze(0).to(self.device)
        prompt = [inputs["prompt"]]  # policy expects a list of prompts, one per batch item

        batch = {
            "observation.images": images,
            "observation.state": state,
            "prompt": prompt,
        }

        # Inference only: no gradients needed.
        with torch.no_grad():
            actions = self.policy.select_action(batch)

        # Return the first (and only) batch item as plain Python lists so the
        # response is JSON-serializable.
        return {"actions": actions[0].tolist()}